You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by re...@apache.org on 2023/07/22 17:40:51 UTC
[beam] branch master updated: Merge pull request #27617: Support withFanout and withHotKeyFanout on schema group transform
This is an automated email from the ASF dual-hosted git repository.
reuvenlax pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 05305ede453 Merge pull request #27617: Support withFanout and withHotKeyFanout on schema group transform
05305ede453 is described below
commit 05305ede45366f158f27fc2b83b9ce00db4df2ab
Author: Reuven Lax <re...@google.com>
AuthorDate: Sat Jul 22 10:40:43 2023 -0700
Merge pull request #27617: Support withFanout and withHotKeyFanout on schema group transform
---
.../apache/beam/sdk/schemas/transforms/Group.java | 162 ++++++++++++++++-----
.../beam/sdk/schemas/transforms/GroupTest.java | 96 +++++++++---
.../sdk/extensions/sql/impl/rel/BeamWindowRel.java | 5 +-
3 files changed, 201 insertions(+), 62 deletions(-)
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
index fb48ceed311..fe8933d24d5 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
@@ -17,7 +17,9 @@
*/
package org.apache.beam.sdk.schemas.transforms;
+import com.google.auto.value.AutoOneOf;
import com.google.auto.value.AutoValue;
+import java.io.Serializable;
import java.util.List;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
@@ -34,6 +36,7 @@ import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
@@ -42,6 +45,7 @@ import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.checkerframework.checker.nullness.qual.Nullable;
/**
* A generic grouping transform for schema {@link PCollection}s.
@@ -153,7 +157,7 @@ public class Group {
*/
public <OutputT> CombineGlobally<InputT, OutputT> aggregate(
CombineFn<InputT, ?, OutputT> combineFn) {
- return new CombineGlobally<>(combineFn);
+ return new CombineGlobally<>(combineFn, 0);
}
/**
@@ -169,10 +173,8 @@ public class Group {
return new CombineFieldsGlobally<>(
SchemaAggregateFn.create()
.aggregateFields(
- FieldAccessDescriptor.withFieldNames(inputFieldName),
- false,
- fn,
- outputFieldName));
+ FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName),
+ 0);
}
public <CombineInputT, AccumT, CombineOutputT>
@@ -183,7 +185,8 @@ public class Group {
return new CombineFieldsGlobally<>(
SchemaAggregateFn.create()
.aggregateFields(
- FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName));
+ FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName),
+ 0);
}
/** The same as {@link #aggregateField} but using field id. */
@@ -194,7 +197,8 @@ public class Group {
return new CombineFieldsGlobally<>(
SchemaAggregateFn.create()
.aggregateFields(
- FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName));
+ FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName),
+ 0);
}
public <CombineInputT, AccumT, CombineOutputT>
@@ -205,7 +209,8 @@ public class Group {
return new CombineFieldsGlobally<>(
SchemaAggregateFn.create()
.aggregateFields(
- FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName));
+ FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName),
+ 0);
}
/**
@@ -221,7 +226,8 @@ public class Group {
return new CombineFieldsGlobally<>(
SchemaAggregateFn.create()
.aggregateFields(
- FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField));
+ FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField),
+ 0);
}
public <CombineInputT, AccumT, CombineOutputT>
@@ -232,7 +238,8 @@ public class Group {
return new CombineFieldsGlobally<>(
SchemaAggregateFn.create()
.aggregateFields(
- FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField));
+ FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField),
+ 0);
}
/** The same as {@link #aggregateField} but using field id. */
@@ -241,7 +248,8 @@ public class Group {
return new CombineFieldsGlobally<>(
SchemaAggregateFn.create()
.aggregateFields(
- FieldAccessDescriptor.withFieldIds(inputFielId), false, fn, outputField));
+ FieldAccessDescriptor.withFieldIds(inputFielId), false, fn, outputField),
+ 0);
}
public <CombineInputT, AccumT, CombineOutputT>
@@ -252,7 +260,8 @@ public class Group {
return new CombineFieldsGlobally<>(
SchemaAggregateFn.create()
.aggregateFields(
- FieldAccessDescriptor.withFieldIds(inputFielId), true, fn, outputField));
+ FieldAccessDescriptor.withFieldIds(inputFielId), true, fn, outputField),
+ 0);
}
/**
@@ -298,8 +307,8 @@ public class Group {
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
String outputFieldName) {
return new CombineFieldsGlobally<>(
- SchemaAggregateFn.create()
- .aggregateFields(fieldsToAggregate, false, fn, outputFieldName));
+ SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputFieldName),
+ 0);
}
/**
@@ -335,7 +344,7 @@ public class Group {
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
Field outputField) {
return new CombineFieldsGlobally<>(
- SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputField));
+ SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputField), 0);
}
@Override
@@ -351,14 +360,20 @@ public class Group {
public static class CombineGlobally<InputT, OutputT>
extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
final CombineFn<InputT, ?, OutputT> combineFn;
+ // Final like combineFn: withFanout returns a new transform rather than mutating this one.
+ final int fanout;
- CombineGlobally(CombineFn<InputT, ?, OutputT> combineFn) {
+ CombineGlobally(CombineFn<InputT, ?, OutputT> combineFn, int fanout) {
this.combineFn = combineFn;
+ this.fanout = fanout;
+ }
+
+ /**
+ * Returns a copy of this transform, with the same combine function, that passes {@code fanout}
+ * to {@code Combine.globally(...).withFanout(...)} when expanded.
+ */
+ public CombineGlobally<InputT, OutputT> withFanout(int fanout) {
+ return new CombineGlobally<>(combineFn, fanout);
}
@Override
public PCollection<OutputT> expand(PCollection<InputT> input) {
- return input.apply("globalCombine", Combine.globally(combineFn));
+ return input.apply("globalCombine", Combine.globally(combineFn).withFanout(fanout));
}
}
@@ -420,9 +435,11 @@ public class Group {
*/
public static class CombineFieldsGlobally<InputT> extends AggregateCombiner<InputT> {
private final SchemaAggregateFn.Inner schemaAggregateFn;
+ private final int fanout;
- CombineFieldsGlobally(SchemaAggregateFn.Inner schemaAggregateFn) {
+ CombineFieldsGlobally(SchemaAggregateFn.Inner schemaAggregateFn, int fanout) {
this.schemaAggregateFn = schemaAggregateFn;
+ this.fanout = fanout;
}
/**
@@ -431,7 +448,7 @@ public class Group {
* determined by the output types of all the composed combiners.
*/
public static CombineFieldsGlobally<?> create() {
- return new CombineFieldsGlobally<>(SchemaAggregateFn.create());
+ return new CombineFieldsGlobally<>(SchemaAggregateFn.create(), 0);
}
/**
@@ -450,7 +467,8 @@ public class Group {
String outputFieldName) {
return new CombineFieldsGlobally<>(
schemaAggregateFn.aggregateFields(
- FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName));
+ FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName),
+ fanout);
}
public <CombineInputT, AccumT, CombineOutputT>
@@ -460,7 +478,8 @@ public class Group {
String outputFieldName) {
return new CombineFieldsGlobally<>(
schemaAggregateFn.aggregateFields(
- FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName));
+ FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName),
+ fanout);
}
public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField(
@@ -469,7 +488,8 @@ public class Group {
String outputFieldName) {
return new CombineFieldsGlobally<>(
schemaAggregateFn.aggregateFields(
- FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName));
+ FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName),
+ fanout);
}
public <CombineInputT, AccumT, CombineOutputT>
@@ -479,7 +499,8 @@ public class Group {
String outputFieldName) {
return new CombineFieldsGlobally<>(
schemaAggregateFn.aggregateFields(
- FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName));
+ FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName),
+ fanout);
}
/**
@@ -495,7 +516,8 @@ public class Group {
Field outputField) {
return new CombineFieldsGlobally<>(
schemaAggregateFn.aggregateFields(
- FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField));
+ FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField),
+ fanout);
}
public <CombineInputT, AccumT, CombineOutputT>
@@ -505,7 +527,8 @@ public class Group {
Field outputField) {
return new CombineFieldsGlobally<>(
schemaAggregateFn.aggregateFields(
- FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField));
+ FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField),
+ fanout);
}
@Override
@@ -513,7 +536,8 @@ public class Group {
int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
return new CombineFieldsGlobally<>(
schemaAggregateFn.aggregateFields(
- FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputField));
+ FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputField),
+ fanout);
}
public <CombineInputT, AccumT, CombineOutputT>
@@ -523,7 +547,8 @@ public class Group {
Field outputField) {
return new CombineFieldsGlobally<>(
schemaAggregateFn.aggregateFields(
- FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputField));
+ FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputField),
+ fanout);
}
/**
@@ -568,7 +593,8 @@ public class Group {
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
String outputFieldName) {
return new CombineFieldsGlobally<>(
- schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputFieldName));
+ schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputFieldName),
+ fanout);
}
/**
@@ -605,13 +631,17 @@ public class Group {
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
Field outputField) {
return new CombineFieldsGlobally<>(
- schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputField));
+ schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputField), fanout);
+ }
+
+ /**
+ * Returns a copy of this transform, keeping the aggregations built so far, that applies {@code
+ * fanout} via {@code Combine.globally(...).withFanout(fanout)} when expanded.
+ */
+ public CombineFieldsGlobally<InputT> withFanout(int fanout) {
+ return new CombineFieldsGlobally<>(schemaAggregateFn, fanout);
}
@Override
public PCollection<Row> expand(PCollection<InputT> input) {
SchemaAggregateFn.Inner fn = schemaAggregateFn.withSchema(input.getSchema());
- Combine.Globally<Row, Row> combineFn = Combine.globally(fn);
+ Combine.Globally<Row, Row> combineFn = Combine.globally(fn).withFanout(fanout);
if (!(input.getWindowingStrategy().getWindowFn() instanceof GlobalWindows)) {
combineFn = combineFn.withoutDefaults();
}
@@ -631,6 +661,7 @@ public class Group {
*/
@AutoValue
public abstract static class ByFields<InputT> extends AggregateCombiner<InputT> {
+
abstract FieldAccessDescriptor getFieldAccessDescriptor();
abstract String getKeyField();
@@ -651,11 +682,11 @@ public class Group {
abstract ByFields<InputT> build();
}
- class ToKv extends PTransform<PCollection<InputT>, PCollection<KV<Row, Iterable<Row>>>> {
+ class ToKV extends PTransform<PCollection<InputT>, PCollection<KV<Row, Row>>> {
private RowSelector rowSelector;
@Override
- public PCollection<KV<Row, Iterable<Row>>> expand(PCollection<InputT> input) {
+ public PCollection<KV<Row, Row>> expand(PCollection<InputT> input) {
Schema schema = input.getSchema();
FieldAccessDescriptor resolved = getFieldAccessDescriptor().resolve(schema);
rowSelector = new RowSelectorContainer(schema, resolved, true);
@@ -666,13 +697,12 @@ public class Group {
.apply(
"selectKeys",
WithKeys.of((Row e) -> rowSelector.select(e)).withKeyType(TypeDescriptors.rows()))
- .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema)))
- .apply("GroupByKey", GroupByKey.create());
+ .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema)));
}
}
- public ToKv getToKvs() {
- return new ToKv();
+ public ToKV getToKV() {
+ return new ToKV();
}
private static <InputT> ByFields<InputT> of(FieldAccessDescriptor fieldAccessDescriptor) {
@@ -919,7 +949,8 @@ public class Group {
.build();
return input
- .apply("ToKvs", getToKvs())
+ .apply("ToKvs", getToKV())
+ .apply("GroupByKey", GroupByKey.create())
.apply(
"ToRow",
ParDo.of(
@@ -942,6 +973,31 @@ public class Group {
*/
@AutoValue
public abstract static class CombineFieldsByFields<InputT> extends AggregateCombiner<InputT> {
+
+ /**
+ * Hot-key fanout setting for the per-key combine: either a fixed number or a function that
+ * produces a number from a {@link Row}. {@link #getKind} reports which of the two is held;
+ * {@code getCombineTransform} switches on it before reading the value.
+ */
+ @AutoOneOf(Fanout.Kind.class)
+ public abstract static class Fanout implements Serializable {
+ /** Names which of the two values an instance holds. */
+ public enum Kind {
+ NUMBER,
+ FUNCTION
+ }
+
+ public abstract Kind getKind();
+
+ /** The fixed fanout; read when the kind is {@code NUMBER}. */
+ public abstract Integer getNumber();
+
+ /** The fanout function; read when the kind is {@code FUNCTION}. */
+ public abstract SerializableFunction<Row, Integer> getFunction();
+
+ /** Creates a fanout holding the fixed number {@code n}. */
+ public static Fanout of(int n) {
+ return AutoOneOf_Group_CombineFieldsByFields_Fanout.number(n);
+ }
+
+ /** Creates a fanout holding the function {@code f}. */
+ public static Fanout of(SerializableFunction<Row, Integer> f) {
+ return AutoOneOf_Group_CombineFieldsByFields_Fanout.function(f);
+ }
+ }
+
+ abstract @Nullable Fanout getFanout();
+
abstract ByFields<InputT> getByFields();
abstract SchemaAggregateFn.Inner getSchemaAggregateFn();
@@ -954,6 +1010,8 @@ public class Group {
@AutoValue.Builder
abstract static class Builder<InputT> {
+ public abstract Builder<InputT> setFanout(@Nullable Fanout value);
+
abstract Builder<InputT> setByFields(ByFields<InputT> byFields);
abstract Builder<InputT> setSchemaAggregateFn(SchemaAggregateFn.Inner schemaAggregateFn);
@@ -988,6 +1046,14 @@ public class Group {
return toBuilder().setValueField(valueField).build();
}
+ /**
+ * Returns a copy of this transform whose per-key combine is built with a fixed hot-key fanout
+ * of {@code n}.
+ */
+ public CombineFieldsByFields<InputT> withHotKeyFanout(int n) {
+ Fanout fixedFanout = Fanout.of(n);
+ return toBuilder().setFanout(fixedFanout).build();
+ }
+
+ /**
+ * Returns a copy of this transform whose per-key combine is built with the hot-key fanout
+ * function {@code f}. NOTE(review): {@code f} presumably receives the grouping-key row, as in
+ * {@code Combine.PerKey#withHotKeyFanout} - confirm.
+ */
+ public CombineFieldsByFields<InputT> withHotKeyFanout(SerializableFunction<Row, Integer> f) {
+ Fanout functionFanout = Fanout.of(f);
+ return toBuilder().setFanout(functionFanout).build();
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -1187,9 +1253,25 @@ public class Group {
.build();
}
+ /**
+ * Builds the per-key combine for this transform's aggregations, bound to {@code schema}.
+ *
+ * <p>When a hot-key fanout is configured, it is applied to the combine as either a fixed number
+ * or a function; otherwise the plain per-key combine is returned.
+ *
+ * @throws IllegalStateException if the configured fanout reports a kind this method does not
+ * handle
+ */
+ PTransform<PCollection<KV<Row, Row>>, PCollection<KV<Row, Row>>> getCombineTransform(
+ Schema schema) {
+ // Build the combine once; the fanout variants below only decorate it.
+ Combine.PerKey<Row, Row, Row> perKey =
+ Combine.<Row, Row, Row>perKey(getSchemaAggregateFn().withSchema(schema));
+ @Nullable Fanout fanout = getFanout();
+ if (fanout == null) {
+ return perKey;
+ }
+ switch (fanout.getKind()) {
+ case NUMBER:
+ return perKey.withHotKeyFanout(fanout.getNumber());
+ case FUNCTION:
+ return perKey.withHotKeyFanout(fanout.getFunction());
+ default:
+ throw new IllegalStateException("Unexpected kind: " + fanout.getKind());
+ }
+ }
+
@Override
public PCollection<Row> expand(PCollection<InputT> input) {
- SchemaAggregateFn.Inner fn = getSchemaAggregateFn().withSchema(input.getSchema());
Schema keySchema = getByFields().getKeySchema(input.getSchema());
Schema outputSchema =
@@ -1199,8 +1281,8 @@ public class Group {
.build();
return input
- .apply("ToKvs", getByFields().getToKvs())
- .apply("Combine", Combine.groupedValues(fn))
+ .apply("ToKvs", getByFields().getToKV())
+ .apply("Combine", getCombineTransform(input.getSchema()))
.apply(
"ToRow",
ParDo.of(
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java
index b4b074e0ca0..f6f33208a10 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java
@@ -30,6 +30,7 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
+import javax.annotation.Nullable;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
@@ -50,6 +51,7 @@ import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Sample;
import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.Top;
import org.apache.beam.sdk.values.PCollection;
@@ -260,17 +262,30 @@ public class GroupTest implements Serializable {
@Test
@Category(NeedsRunner.class)
- public void testGlobalAggregation() {
+ public void testGlobalAggregationWithoutFanout() {
+ globalAggregationWithFanout(false);
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testGlobalAggregationWithFanout() {
+ globalAggregationWithFanout(true);
+ }
+
+ public void globalAggregationWithFanout(boolean withFanout) {
Collection<Basic> elements =
ImmutableList.of(
Basic.of("key1", 1, "value1"),
Basic.of("key1", 1, "value2"),
Basic.of("key2", 2, "value3"),
Basic.of("key2", 2, "value4"));
- PCollection<Long> count =
- pipeline
- .apply(Create.of(elements))
- .apply(Group.<Basic>globally().aggregate(Count.combineFn()));
+
+ Group.CombineGlobally<Basic, Long> transform =
+ Group.<Basic>globally().aggregate(Count.combineFn());
+ if (withFanout) {
+ transform = transform.withFanout(10);
+ }
+ PCollection<Long> count = pipeline.apply(Create.of(elements)).apply(transform);
PAssert.that(count).containsInAnyOrder(4L);
pipeline.run();
@@ -426,7 +441,17 @@ public class GroupTest implements Serializable {
@Test
@Category(NeedsRunner.class)
- public void testAggregateByMultipleFields() {
+ public void testAggregateByMultipleFieldsWithoutFanout() {
+ aggregateByMultipleFieldsWithFanout(false);
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testAggregateByMultipleFieldsWithFanout() {
+ aggregateByMultipleFieldsWithFanout(true);
+ }
+
+ public void aggregateByMultipleFieldsWithFanout(boolean withFanout) {
Collection<Aggregate> elements =
ImmutableList.of(
Aggregate.of(1, 1, 2),
@@ -435,12 +460,14 @@ public class GroupTest implements Serializable {
Aggregate.of(4, 2, 5));
List<String> fieldNames = Lists.newArrayList("field1", "field2");
- PCollection<Row> aggregate =
- pipeline
- .apply(Create.of(elements))
- .apply(
- Group.<Aggregate>globally()
- .aggregateFields(fieldNames, new MultipleFieldCombineFn(), "field1+field2"));
+
+ Group.CombineFieldsGlobally<Aggregate> transform =
+ Group.<Aggregate>globally()
+ .aggregateFields(fieldNames, new MultipleFieldCombineFn(), "field1+field2");
+ if (withFanout) {
+ transform = transform.withFanout(10);
+ }
+ PCollection<Row> aggregate = pipeline.apply(Create.of(elements)).apply(transform);
Schema outputSchema = Schema.builder().addInt64Field("field1+field2").build();
Row expectedRow = Row.withSchema(outputSchema).addValues(16L).build();
@@ -462,7 +489,25 @@ public class GroupTest implements Serializable {
@Test
@Category(NeedsRunner.class)
- public void testByKeyWithSchemaAggregateFnNestedFields() {
+ public void testByKeyWithSchemaAggregateFnNestedFieldsNoFanout() {
+ byKeyWithSchemaAggregateFnNestedFieldsWithFanout(null);
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testByKeyWithSchemaAggregateFnNestedFieldsWithNumberFanout() {
+ byKeyWithSchemaAggregateFnNestedFieldsWithFanout(Group.CombineFieldsByFields.Fanout.of(10));
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testByKeyWithSchemaAggregateFnNestedFieldsWithFunctionFanout() {
+ byKeyWithSchemaAggregateFnNestedFieldsWithFanout(
+ Group.CombineFieldsByFields.Fanout.of(SerializableFunctions.constant(10)));
+ }
+
+ public void byKeyWithSchemaAggregateFnNestedFieldsWithFanout(
+ @Nullable Group.CombineFieldsByFields.Fanout fanout) {
Collection<OuterAggregate> elements =
ImmutableList.of(
OuterAggregate.of(Aggregate.of(1, 1, 2)),
@@ -470,14 +515,23 @@ public class GroupTest implements Serializable {
OuterAggregate.of(Aggregate.of(3, 2, 4)),
OuterAggregate.of(Aggregate.of(4, 2, 5)));
- PCollection<Row> aggregations =
- pipeline
- .apply(Create.of(elements))
- .apply(
- Group.<OuterAggregate>byFieldNames("inner.field2")
- .aggregateField("inner.field1", Sum.ofLongs(), "field1_sum")
- .aggregateField("inner.field3", Sum.ofIntegers(), "field3_sum")
- .aggregateField("inner.field1", Top.largestLongsFn(1), "field1_top"));
+ Group.CombineFieldsByFields<OuterAggregate> transform =
+ Group.<OuterAggregate>byFieldNames("inner.field2")
+ .aggregateField("inner.field1", Sum.ofLongs(), "field1_sum")
+ .aggregateField("inner.field3", Sum.ofIntegers(), "field3_sum")
+ .aggregateField("inner.field1", Top.largestLongsFn(1), "field1_top");
+ if (fanout != null) {
+ switch (fanout.getKind()) {
+ case NUMBER:
+ transform = transform.withHotKeyFanout(fanout.getNumber());
+ break;
+ case FUNCTION:
+ transform = transform.withHotKeyFanout(fanout.getFunction());
+ break;
+ }
+ }
+
+ PCollection<Row> aggregations = pipeline.apply(Create.of(elements)).apply(transform);
Schema keySchema = Schema.builder().addInt64Field("field2").build();
Schema valueSchema =
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
index 906258b164a..be88229e755 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
@@ -35,6 +35,7 @@ import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
@@ -254,7 +255,9 @@ public class BeamWindowRel extends Window implements BeamRelNode {
org.apache.beam.sdk.schemas.transforms.Group.ByFields<Row> myg =
org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(af.partitionKeys);
PCollection<KV<Row, Iterable<Row>>> partitionBy =
- inputData.apply(prefix + "partitionBy", myg.getToKvs());
+ inputData
+ .apply(prefix + "partitionByKV", myg.getToKV())
+ .apply(prefix + "partitionByGK", GroupByKey.create());
partitioned =
partitionBy
.apply(prefix + "selectOnlyValues", ParDo.of(new SelectOnlyValues()))