You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lo...@apache.org on 2022/05/10 00:05:31 UTC
[beam] branch master updated: [BEAM-14347] Add generic registration for accumulators (#17579)
This is an automated email from the ASF dual-hosted git repository.
lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new a5fbc8ed163 [BEAM-14347] Add generic registration for accumulators (#17579)
a5fbc8ed163 is described below
commit a5fbc8ed163339879547a8e16e4f4f99a5b3a388
Author: Danny McCormick <da...@google.com>
AuthorDate: Mon May 9 20:05:24 2022 -0400
[BEAM-14347] Add generic registration for accumulators (#17579)
---
sdks/go/pkg/beam/registration/registration.go | 655 ++++++++++++++++++++-
sdks/go/pkg/beam/registration/registration.tmpl | 310 +++++++++-
sdks/go/pkg/beam/registration/registration_test.go | 234 ++++++++
3 files changed, 1197 insertions(+), 2 deletions(-)
diff --git a/sdks/go/pkg/beam/registration/registration.go b/sdks/go/pkg/beam/registration/registration.go
index ba1585d71bf..95057a428ae 100644
--- a/sdks/go/pkg/beam/registration/registration.go
+++ b/sdks/go/pkg/beam/registration/registration.go
@@ -21,6 +21,7 @@ package registration
import (
"context"
+ "fmt"
"reflect"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
@@ -7030,13 +7031,665 @@ type teardown1x1 interface {
Teardown(ctx context.Context) error
}
+type createAccumulator0x1[T any] interface { // 0x1: no parameters, one result
+ CreateAccumulator() T
+}
+
+type createAccumulator0x2[T any] interface { // 0x2: no parameters, one result plus an error
+ CreateAccumulator() (T, error)
+}
+
+type addInput2x1[T1, T2 any] interface { // 2x1: accumulator (T1) and input (T2) in, accumulator out
+ AddInput(a T1, i T2) T1
+}
+
+type addInput2x2[T1, T2 any] interface { // 2x2: accumulator (T1) and input (T2) in, accumulator plus an error out
+ AddInput(a T1, i T2) (T1, error)
+}
+
+type mergeAccumulators2x1[T any] interface { // 2x1: two accumulators in, one accumulator out
+ MergeAccumulators(a0 T, a1 T) T
+}
+
+type mergeAccumulators2x2[T any] interface { // 2x2: two accumulators in, one accumulator plus an error out
+ MergeAccumulators(a0 T, a1 T) (T, error)
+}
+
+type extractOutput1x1[T1, T2 any] interface { // 1x1: accumulator (T1) in, output (T2) out
+ ExtractOutput(a T1) T2
+}
+
+type extractOutput1x2[T1, T2 any] interface { // 1x2: accumulator (T1) in, output (T2) plus an error out
+ ExtractOutput(a T1) (T2, error)
+}
+
+// Combiner1 registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions (Combiner1, Combiner2, Combiner3), one per number of distinct types involved.
+// Combiner1 should be used when your accumulator, input, and output are all of the same type.
+// It can be called with registration.Combiner1[T](&CustomCombiner{})
+// where T is the type of the input/accumulator/output. It panics if MergeAccumulators is absent or any present method does not match T.
+func Combiner1[T0 any](accum interface{}) {
+ registerCombinerTypes(accum) // type registration runs before any signature matching below
+ accumVal := reflect.ValueOf(accum)
+ var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(mergeAccumulators2x2[T0]); ok { // error-returning signature is probed first
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) (T0, error))
+ return &caller2x2[T0, T0, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+ mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+ return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(mergeAccumulators2x1[T0]); ok { // signature without an error result
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) T0)
+ return &caller2x1[T0, T0, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+ mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+ return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+ })
+ }
+ }
+
+ if mergeAccumulatorsWrapper == nil { // MergeAccumulators is mandatory, unlike the methods checked below
+ panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+ }
+
+ var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(createAccumulator0x2[T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func() (T0, error))
+ return &caller0x2[T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+ createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func() (T0, error) {
+ return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+ })
+ }
+ } else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func() T0)
+ return &caller0x1[T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+ createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func() T0 {
+ return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+ })
+ }
+ }
+ if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil { // method exists but did not match T0
+ panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+ }
+
+ var addInputWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(addInput2x2[T0, T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) (T0, error))
+ return &caller2x2[T0, T0, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+ return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) T0)
+ return &caller2x1[T0, T0, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+ return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+ })
+ }
+ }
+
+ if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil { // method exists but did not match T0
+ panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+ }
+
+ var extractOutputWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(extractOutput1x2[T0, T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) (T0, error))
+ return &caller1x2[T0, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, error))(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+ return fn.(extractOutput1x2[T0, T0]).ExtractOutput(a0)
+ })
+ }
+ } else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) T0)
+ return &caller1x1[T0, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T0)(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) T0 {
+ return fn.(extractOutput1x1[T0, T0]).ExtractOutput(a0)
+ })
+ }
+ }
+
+ if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil { // method exists but did not match T0
+ panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+ }
+
+ wrapperFn := func(fn interface{}) map[string]reflectx.Func { // method name -> wrapped method, only for the methods matched above
+ m := map[string]reflectx.Func{}
+ if mergeAccumulatorsWrapper != nil {
+ m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+ }
+ if createAccumulatorWrapper != nil {
+ m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+ }
+ if addInputWrapper != nil {
+ m["AddInput"] = addInputWrapper(fn)
+ }
+ if extractOutputWrapper != nil {
+ m["ExtractOutput"] = extractOutputWrapper(fn)
+ }
+
+ return m
+ }
+ reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn) // Elem(): accum is expected to be a pointer to the combiner struct
+}
+
+// Combiner2 registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions (Combiner1, Combiner2, Combiner3), one per number of distinct types involved.
+// Combiner2 should be used when your accumulator, input, and output are 2 distinct types.
+// It can be called with registration.Combiner2[T1, T2](&CustomCombiner{})
+// where T1 is the type of the accumulator and T2 is the other type. It panics if MergeAccumulators is absent or any present method fits none of the probed signatures.
+func Combiner2[T0, T1 any](accum interface{}) {
+ registerCombinerTypes(accum) // type registration runs before any signature matching below
+ accumVal := reflect.ValueOf(accum)
+ var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(mergeAccumulators2x2[T0]); ok { // error-returning signature is probed first
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) (T0, error))
+ return &caller2x2[T0, T0, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+ mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+ return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(mergeAccumulators2x1[T0]); ok { // signature without an error result
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) T0)
+ return &caller2x1[T0, T0, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+ mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+ return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+ })
+ }
+ }
+
+ if mergeAccumulatorsWrapper == nil { // MergeAccumulators is mandatory, unlike the methods checked below
+ panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+ }
+
+ var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(createAccumulator0x2[T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func() (T0, error))
+ return &caller0x2[T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+ createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func() (T0, error) {
+ return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+ })
+ }
+ } else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func() T0)
+ return &caller0x1[T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+ createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func() T0 {
+ return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+ })
+ }
+ }
+ if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil { // method exists but did not match T0
+ panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+ }
+
+ var addInputWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(addInput2x2[T0, T0]); ok { // input of the accumulator type is probed before T1
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) (T0, error))
+ return &caller2x2[T0, T0, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+ return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) T0)
+ return &caller2x1[T0, T0, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+ return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(addInput2x2[T0, T1]); ok { // input of type T1
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T1) (T0, error))
+ return &caller2x2[T0, T1, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) (T0, error))(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T1) (T0, error) {
+ return fn.(addInput2x2[T0, T1]).AddInput(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(addInput2x1[T0, T1]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T1) T0)
+ return &caller2x1[T0, T1, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) T0)(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T1) T0 {
+ return fn.(addInput2x1[T0, T1]).AddInput(a0, a1)
+ })
+ }
+ }
+
+ if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil { // method exists but matched no probed signature
+ panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+ }
+
+ var extractOutputWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(extractOutput1x2[T0, T0]); ok { // output of the accumulator type is probed before T1
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) (T0, error))
+ return &caller1x2[T0, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, error))(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+ return fn.(extractOutput1x2[T0, T0]).ExtractOutput(a0)
+ })
+ }
+ } else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) T0)
+ return &caller1x1[T0, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T0)(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) T0 {
+ return fn.(extractOutput1x1[T0, T0]).ExtractOutput(a0)
+ })
+ }
+ } else if _, ok := accum.(extractOutput1x2[T0, T1]); ok { // output of type T1
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) (T1, error))
+ return &caller1x2[T0, T1, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T1, error))(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) (T1, error) {
+ return fn.(extractOutput1x2[T0, T1]).ExtractOutput(a0)
+ })
+ }
+ } else if _, ok := accum.(extractOutput1x1[T0, T1]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) T1)
+ return &caller1x1[T0, T1]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T1)(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) T1 {
+ return fn.(extractOutput1x1[T0, T1]).ExtractOutput(a0)
+ })
+ }
+ }
+
+ if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil { // method exists but matched no probed signature
+ panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+ }
+
+ wrapperFn := func(fn interface{}) map[string]reflectx.Func { // method name -> wrapped method, only for the methods matched above
+ m := map[string]reflectx.Func{}
+ if mergeAccumulatorsWrapper != nil {
+ m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+ }
+ if createAccumulatorWrapper != nil {
+ m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+ }
+ if addInputWrapper != nil {
+ m["AddInput"] = addInputWrapper(fn)
+ }
+ if extractOutputWrapper != nil {
+ m["ExtractOutput"] = extractOutputWrapper(fn)
+ }
+
+ return m
+ }
+ reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn) // Elem(): accum is expected to be a pointer to the combiner struct
+}
+
+// Combiner3 registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions (Combiner1, Combiner2, Combiner3), one per number of distinct types involved.
+// Combiner3 should be used when your accumulator, input, and output are 3 distinct types.
+// It can be called with registration.Combiner3[T1, T2, T3](&CustomCombiner{})
+// where T1 is the type of the accumulator, T2 is the type of the input, and T3 is the type of the output. It panics if MergeAccumulators is absent or any present method fits none of the probed signatures.
+func Combiner3[T0, T1, T2 any](accum interface{}) {
+ registerCombinerTypes(accum) // type registration runs before any signature matching below
+ accumVal := reflect.ValueOf(accum)
+ var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(mergeAccumulators2x2[T0]); ok { // error-returning signature is probed first
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) (T0, error))
+ return &caller2x2[T0, T0, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+ mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+ return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(mergeAccumulators2x1[T0]); ok { // signature without an error result
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) T0)
+ return &caller2x1[T0, T0, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+ mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+ return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+ })
+ }
+ }
+
+ if mergeAccumulatorsWrapper == nil { // MergeAccumulators is mandatory, unlike the methods checked below
+ panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+ }
+
+ var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(createAccumulator0x2[T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func() (T0, error))
+ return &caller0x2[T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+ createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func() (T0, error) {
+ return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+ })
+ }
+ } else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func() T0)
+ return &caller0x1[T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+ createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func() T0 {
+ return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+ })
+ }
+ }
+ if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil { // method exists but did not match T0
+ panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+ }
+
+ var addInputWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(addInput2x2[T0, T0]); ok { // input candidates are probed in the order T0, T1, T2
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) (T0, error))
+ return &caller2x2[T0, T0, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+ return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T0) T0)
+ return &caller2x1[T0, T0, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+ return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(addInput2x2[T0, T1]); ok { // input of type T1
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T1) (T0, error))
+ return &caller2x2[T0, T1, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) (T0, error))(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T1) (T0, error) {
+ return fn.(addInput2x2[T0, T1]).AddInput(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(addInput2x1[T0, T1]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T1) T0)
+ return &caller2x1[T0, T1, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) T0)(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T1) T0 {
+ return fn.(addInput2x1[T0, T1]).AddInput(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(addInput2x2[T0, T2]); ok { // input of type T2 is accepted as well
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T2) (T0, error))
+ return &caller2x2[T0, T2, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T2) (T0, error))(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T2) (T0, error) {
+ return fn.(addInput2x2[T0, T2]).AddInput(a0, a1)
+ })
+ }
+ } else if _, ok := accum.(addInput2x1[T0, T2]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0, T2) T0)
+ return &caller2x1[T0, T2, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T2) T0)(nil)).Elem(), caller)
+
+ addInputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0, a1 T2) T0 {
+ return fn.(addInput2x1[T0, T2]).AddInput(a0, a1)
+ })
+ }
+ }
+
+ if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil { // method exists but matched no probed signature
+ panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+ }
+
+ var extractOutputWrapper func(fn interface{}) reflectx.Func
+ if _, ok := accum.(extractOutput1x2[T0, T0]); ok { // output candidates are probed in the order T0, T1, T2
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) (T0, error))
+ return &caller1x2[T0, T0, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, error))(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+ return fn.(extractOutput1x2[T0, T0]).ExtractOutput(a0)
+ })
+ }
+ } else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) T0)
+ return &caller1x1[T0, T0]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T0)(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) T0 {
+ return fn.(extractOutput1x1[T0, T0]).ExtractOutput(a0)
+ })
+ }
+ } else if _, ok := accum.(extractOutput1x2[T0, T1]); ok { // output of type T1 is accepted as well
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) (T1, error))
+ return &caller1x2[T0, T1, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T1, error))(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) (T1, error) {
+ return fn.(extractOutput1x2[T0, T1]).ExtractOutput(a0)
+ })
+ }
+ } else if _, ok := accum.(extractOutput1x1[T0, T1]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) T1)
+ return &caller1x1[T0, T1]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T1)(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) T1 {
+ return fn.(extractOutput1x1[T0, T1]).ExtractOutput(a0)
+ })
+ }
+ } else if _, ok := accum.(extractOutput1x2[T0, T2]); ok { // output of type T2
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) (T2, error))
+ return &caller1x2[T0, T2, error]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T2, error))(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) (T2, error) {
+ return fn.(extractOutput1x2[T0, T2]).ExtractOutput(a0)
+ })
+ }
+ } else if _, ok := accum.(extractOutput1x1[T0, T2]); ok {
+ caller := func(fn interface{}) reflectx.Func {
+ f := fn.(func(T0) T2)
+ return &caller1x1[T0, T2]{fn: f}
+ }
+ reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T2)(nil)).Elem(), caller)
+
+ extractOutputWrapper = func(fn interface{}) reflectx.Func {
+ return reflectx.MakeFunc(func(a0 T0) T2 {
+ return fn.(extractOutput1x1[T0, T2]).ExtractOutput(a0)
+ })
+ }
+ }
+
+ if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil { // method exists but matched no probed signature
+ panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+ }
+
+ wrapperFn := func(fn interface{}) map[string]reflectx.Func { // method name -> wrapped method, only for the methods matched above
+ m := map[string]reflectx.Func{}
+ if mergeAccumulatorsWrapper != nil {
+ m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+ }
+ if createAccumulatorWrapper != nil {
+ m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+ }
+ if addInputWrapper != nil {
+ m["AddInput"] = addInputWrapper(fn)
+ }
+ if extractOutputWrapper != nil {
+ m["ExtractOutput"] = extractOutputWrapper(fn)
+ }
+
+ return m
+ }
+ reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn) // Elem(): accum is expected to be a pointer to the combiner struct
+}
+
+// registerCombinerTypes registers the combiner struct and hands each present method type to registerMethodTypes.
+func registerCombinerTypes(accum interface{}) {
+ // Register the combiner
+ runtime.RegisterType(reflect.TypeOf(accum).Elem())
+ schema.RegisterType(reflect.TypeOf(accum).Elem())
+
+ // Register all types in the Combiner.
+ // There may be different types across MergeAccumulators, AddInput, and ExtractOutput.
+ // MethodByName yields the zero Value for a missing method and Type() panics on it, so guard each lookup.
+ accumVal := reflect.ValueOf(accum)
+ for _, name := range []string{"MergeAccumulators", "AddInput", "ExtractOutput"} {
+ if m := accumVal.MethodByName(name); m.IsValid() {
+ registerMethodTypes(m.Type())
+ }
+ }
+}
+
func registerDoFnTypes(doFn interface{}) {
// Register the doFn
runtime.RegisterType(reflect.TypeOf(doFn).Elem())
schema.RegisterType(reflect.TypeOf(doFn).Elem())
// Register all types in the DoFn
- fn := reflect.ValueOf(doFn).MethodByName("ProcessElement").Type()
+ registerMethodTypes(reflect.ValueOf(doFn).MethodByName("ProcessElement").Type())
+}
+
+func registerMethodTypes(fn reflect.Type) {
for i := 0; i < fn.NumIn(); i++ {
in := reflectx.SkipPtr(fn.In(i))
if in.Kind() == reflect.Struct {
diff --git a/sdks/go/pkg/beam/registration/registration.tmpl b/sdks/go/pkg/beam/registration/registration.tmpl
index e046a9d165f..f73c0e323e2 100644
--- a/sdks/go/pkg/beam/registration/registration.tmpl
+++ b/sdks/go/pkg/beam/registration/registration.tmpl
@@ -118,6 +118,7 @@ package registration
import (
"context"
+ "fmt"
"reflect"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
@@ -241,13 +242,320 @@ type teardown1x1 interface {
Teardown(ctx context.Context) error
}
+// createAccumulator0x1 is satisfied by combiners whose CreateAccumulator
+// takes no arguments and returns only the accumulator.
+type createAccumulator0x1[T any] interface {
+ CreateAccumulator() T
+}
+
+// createAccumulator0x2 is satisfied by combiners whose CreateAccumulator
+// takes no arguments and returns the accumulator together with an error.
+type createAccumulator0x2[T any] interface {
+ CreateAccumulator() (T, error)
+}
+
+// addInput2x1 is satisfied by combiners whose AddInput takes an accumulator
+// of type T1 and an input of type T2 and returns only the accumulator.
+type addInput2x1[T1, T2 any] interface {
+ AddInput(a T1, i T2) T1
+}
+
+// addInput2x2 is satisfied by combiners whose AddInput takes an accumulator
+// of type T1 and an input of type T2 and returns the accumulator and an error.
+type addInput2x2[T1, T2 any] interface {
+ AddInput(a T1, i T2) (T1, error)
+}
+
+// mergeAccumulators2x1 is satisfied by combiners whose MergeAccumulators
+// takes two accumulators and returns only the merged accumulator.
+type mergeAccumulators2x1[T any] interface {
+ MergeAccumulators(a0 T, a1 T) T
+}
+
+// mergeAccumulators2x2 is satisfied by combiners whose MergeAccumulators
+// takes two accumulators and returns the merged accumulator and an error.
+type mergeAccumulators2x2[T any] interface {
+ MergeAccumulators(a0 T, a1 T) (T, error)
+}
+
+// extractOutput1x1 is satisfied by combiners whose ExtractOutput takes an
+// accumulator of type T1 and returns only an output of type T2.
+type extractOutput1x1[T1, T2 any] interface {
+ ExtractOutput(a T1) T2
+}
+
+// extractOutput1x2 is satisfied by combiners whose ExtractOutput takes an
+// accumulator of type T1 and returns an output of type T2 and an error.
+type extractOutput1x2[T1, T2 any] interface {
+ ExtractOutput(a T1) (T2, error)
+}
+
+{{range $accum := upto 3}}{{$genericParams := (add $accum 1)}}
+// Combiner{{$genericParams}} registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions, each of which should be used for a different situation.
+{{if (eq $genericParams 1)}}// Combiner1 should be used when your accumulator, input, and output are all of the same type.
+// It can be called with registration.Combiner1[T](&CustomCombiner{})
+// where T is the type of the input/accumulator/output.
+{{else}}{{if (eq $genericParams 2)}}// Combiner2 should be used when your accumulator, input, and output are 2 distinct types.
+// It can be called with registration.Combiner2[T1, T2](&CustomCombiner{})
+// where T1 is the type of the accumulator and T2 is the other type.
+{{else}}// Combiner3 should be used when your accumulator, input, and output are 3 distinct types.
+// It can be called with registration.Combiner3[T1, T2, T3](&CustomCombiner{})
+// where T1 is the type of the accumulator, T2 is the type of the input, and T3 is the type of the output.
+{{end}}{{end}}//
+// It panics if MergeAccumulators cannot be matched against the accumulator type,
+// or if CreateAccumulator, AddInput, or ExtractOutput is defined on the combiner
+// with a signature that matches none of the supplied type parameters.
+func Combiner{{$genericParams}}[{{range $paramNum := upto $genericParams}}{{if $paramNum}}, {{end}}T{{$paramNum}}{{end}} any](accum interface{}) {
+	var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(mergeAccumulators2x2[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) (T0, error))
+			return &caller2x2[T0, T0, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+		mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+				return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(mergeAccumulators2x1[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) T0)
+			return &caller2x1[T0, T0, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+		mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+				return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+			})
+		}
+	}
+
+	if mergeAccumulatorsWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+	}
+
+	// Type registration reflects on the MergeAccumulators method, so it runs only
+	// once the check above has established that the method exists. Running it
+	// first would replace the message above with a reflect zero-Value panic.
+	registerCombinerTypes(accum)
+	accumVal := reflect.ValueOf(accum)
+
+	var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(createAccumulator0x2[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func() (T0, error))
+			return &caller0x2[T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+		createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func() (T0, error) {
+				return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+			})
+		}
+	} else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func() T0)
+			return &caller0x1[T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+		createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func() T0 {
+				return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+			})
+		}
+	}
+	if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+	}
+
+	// The input type may be any of the type parameters, so probe them in order
+	// (T0 first), trying the error-returning signature before the plain one.
+	var addInputWrapper func(fn interface{}) reflectx.Func
+	{{range $in := upto $genericParams}}{{if $in}} else {{end}}if _, ok := accum.(addInput2x2[T0, T{{$in}}]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T{{$in}}) (T0, error))
+			return &caller2x2[T0, T{{$in}}, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T{{$in}}) (T0, error))(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T{{$in}}) (T0, error) {
+				return fn.(addInput2x2[T0, T{{$in}}]).AddInput(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(addInput2x1[T0, T{{$in}}]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T{{$in}}) T0)
+			return &caller2x1[T0, T{{$in}}, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T{{$in}}) T0)(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T{{$in}}) T0 {
+				return fn.(addInput2x1[T0, T{{$in}}]).AddInput(a0, a1)
+			})
+		}
+	}{{end}}
+
+	if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+	}
+
+	// The output type may likewise be any of the type parameters; same probe order.
+	var extractOutputWrapper func(fn interface{}) reflectx.Func
+	{{range $out := upto $genericParams}}{{if $out}} else {{end}}if _, ok := accum.(extractOutput1x2[T0, T{{$out}}]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) (T{{$out}}, error))
+			return &caller1x2[T0, T{{$out}}, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T{{$out}}, error))(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) (T{{$out}}, error) {
+				return fn.(extractOutput1x2[T0, T{{$out}}]).ExtractOutput(a0)
+			})
+		}
+	} else if _, ok := accum.(extractOutput1x1[T0, T{{$out}}]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) T{{$out}})
+			return &caller1x1[T0, T{{$out}}]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T{{$out}})(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) T{{$out}} {
+				return fn.(extractOutput1x1[T0, T{{$out}}]).ExtractOutput(a0)
+			})
+		}
+	}{{end}}
+
+	if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+	}
+
+	wrapperFn := func(fn interface{}) map[string]reflectx.Func {
+		// MergeAccumulators was validated above and is always present; the
+		// remaining methods are only wrapped when the combiner defines them.
+		m := map[string]reflectx.Func{
+			"MergeAccumulators": mergeAccumulatorsWrapper(fn),
+		}
+		if createAccumulatorWrapper != nil {
+			m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+		}
+		if addInputWrapper != nil {
+			m["AddInput"] = addInputWrapper(fn)
+		}
+		if extractOutputWrapper != nil {
+			m["ExtractOutput"] = extractOutputWrapper(fn)
+		}
+
+		return m
+	}
+	reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn)
+}{{end}}
+
+// registerCombinerTypes registers the combiner struct with the runtime and
+// schema registries and passes the type of each of its MergeAccumulators,
+// AddInput, and ExtractOutput methods to registerMethodTypes.
+//
+// accum is expected to be a pointer to the combiner struct; the pointed-to
+// type is what gets registered. Methods the combiner does not define are
+// skipped here, leaving it to the calling Combiner function to report a
+// missing or mistyped method with its own message.
+func registerCombinerTypes(accum interface{}) {
+	// Register the combiner
+	accumType := reflect.TypeOf(accum).Elem()
+	runtime.RegisterType(accumType)
+	schema.RegisterType(accumType)
+
+	// Register all types in the Combiner.
+	// There may be different types across MergeAccumulators, AddInput, and ExtractOutput.
+	accumVal := reflect.ValueOf(accum)
+	for _, name := range []string{"MergeAccumulators", "AddInput", "ExtractOutput"} {
+		// MethodByName yields the zero Value for a missing method, and calling
+		// Type on a zero Value panics inside reflect, so check validity first.
+		if m := accumVal.MethodByName(name); m.IsValid() {
+			registerMethodTypes(m.Type())
+		}
+	}
+}
+
+// registerDoFnTypes registers the DoFn struct with the runtime and schema
+// registries and passes the type of its ProcessElement method to
+// registerMethodTypes. doFn is expected to be a pointer to the DoFn struct.
+// NOTE(review): MethodByName returns the zero Value when ProcessElement is
+// absent, in which case Type would panic; confirm callers validate first.
func registerDoFnTypes(doFn interface{}) {
// Register the doFn
runtime.RegisterType(reflect.TypeOf(doFn).Elem())
schema.RegisterType(reflect.TypeOf(doFn).Elem())
// Register all types in the DoFn
- fn := reflect.ValueOf(doFn).MethodByName("ProcessElement").Type()
+ registerMethodTypes(reflect.ValueOf(doFn).MethodByName("ProcessElement").Type())
+}
+
+func registerMethodTypes(fn reflect.Type) {
for i := 0; i < fn.NumIn(); i++ {
in := reflectx.SkipPtr(fn.In(i))
if in.Kind() == reflect.Struct {
diff --git a/sdks/go/pkg/beam/registration/registration_test.go b/sdks/go/pkg/beam/registration/registration_test.go
index 4f5c37dd2a8..2b535692ff2 100644
--- a/sdks/go/pkg/beam/registration/registration_test.go
+++ b/sdks/go/pkg/beam/registration/registration_test.go
@@ -159,6 +159,168 @@ func TestRegister_RegistersTypes(t *testing.T) {
}
}
+// TestCombiner_CompleteCombiner3 checks that Combiner3 registers a wrapper for
+// each of the four combiner methods of CompleteCombiner3 and that calling each
+// wrapper produces the same result as the underlying method.
+func TestCombiner_CompleteCombiner3(t *testing.T) {
+	accum := &CompleteCombiner3{}
+	Combiner3[int, CustomType, CustomType2](accum)
+
+	m, ok := reflectx.WrapMethods(&CompleteCombiner3{})
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no registered entry found")
+	}
+	ca, ok := m["CreateAccumulator"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no registered entry found for CreateAccumulator")
+	}
+	if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+		t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", got, want)
+	}
+	ai, ok := m["AddInput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no registered entry found for AddInput")
+	}
+	if got, want := ai.Call([]interface{}{2, CustomType{val: 3}})[0].(int), 5; got != want {
+		t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+	}
+	ma, ok := m["MergeAccumulators"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no registered entry found for MergeAccumulators")
+	}
+	if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+		t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", got, want)
+	}
+	// The failure messages below name ExtractOutput so that a failure points at
+	// the method actually under test.
+	eo, ok := m["ExtractOutput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no registered entry found for ExtractOutput")
+	}
+	if got, want := eo.Call([]interface{}{2})[0].(CustomType2).val2, 2; got != want {
+		t.Errorf("Wrapped ExtractOutput call: got %v, want %v", got, want)
+	}
+}
+
+// TestCombiner_RegistersTypes checks that Combiner3 registers the combiner
+// struct and the struct types used in its method signatures with the schema
+// registry.
+func TestCombiner_RegistersTypes(t *testing.T) {
+	accum := &CompleteCombiner3{}
+	Combiner3[int, CustomType, CustomType2](accum)
+
+	// Need to call FromType so that the registry will reconcile its registrations
+	schema.FromType(reflect.TypeOf(accum).Elem())
+	if !schema.Registered(reflect.TypeOf(accum).Elem()) {
+		t.Errorf("schema.Registered(reflect.TypeOf(CompleteCombiner3{})) = false, want true")
+	}
+	if !schema.Registered(reflect.TypeOf(CustomType{})) {
+		t.Errorf("schema.Registered(reflect.TypeOf(CustomType{})) = false, want true")
+	}
+	if !schema.Registered(reflect.TypeOf(CustomType2{})) {
+		t.Errorf("schema.Registered(reflect.TypeOf(CustomType2{})) = false, want true")
+	}
+}
+
+// TestCombiner_CompleteCombiner2 checks that Combiner2 registers a wrapper for
+// each of the four combiner methods of CompleteCombiner2 and that calling each
+// wrapper produces the same result as the underlying method.
+func TestCombiner_CompleteCombiner2(t *testing.T) {
+	accum := &CompleteCombiner2{}
+	Combiner2[int, CustomType](accum)
+
+	m, ok := reflectx.WrapMethods(&CompleteCombiner2{})
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no registered entry found")
+	}
+	ca, ok := m["CreateAccumulator"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no registered entry found for CreateAccumulator")
+	}
+	if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+		t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", got, want)
+	}
+	ai, ok := m["AddInput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no registered entry found for AddInput")
+	}
+	if got, want := ai.Call([]interface{}{2, CustomType{val: 3}})[0].(int), 5; got != want {
+		t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+	}
+	ma, ok := m["MergeAccumulators"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no registered entry found for MergeAccumulators")
+	}
+	if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+		t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", got, want)
+	}
+	// The failure messages below name ExtractOutput so that a failure points at
+	// the method actually under test.
+	eo, ok := m["ExtractOutput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no registered entry found for ExtractOutput")
+	}
+	if got, want := eo.Call([]interface{}{2})[0].(CustomType).val, 2; got != want {
+		t.Errorf("Wrapped ExtractOutput call: got %v, want %v", got, want)
+	}
+}
+
+// TestCombiner_CompleteCombiner1 checks that Combiner1 registers a wrapper for
+// each of the four combiner methods of CompleteCombiner1 and that calling each
+// wrapper produces the same result as the underlying method.
+func TestCombiner_CompleteCombiner1(t *testing.T) {
+	accum := &CompleteCombiner1{}
+	Combiner1[int](accum)
+
+	m, ok := reflectx.WrapMethods(&CompleteCombiner1{})
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no registered entry found")
+	}
+	ca, ok := m["CreateAccumulator"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no registered entry found for CreateAccumulator")
+	}
+	if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+		t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", got, want)
+	}
+	ai, ok := m["AddInput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no registered entry found for AddInput")
+	}
+	if got, want := ai.Call([]interface{}{2, 3})[0].(int), 5; got != want {
+		t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+	}
+	ma, ok := m["MergeAccumulators"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no registered entry found for MergeAccumulators")
+	}
+	if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+		t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", got, want)
+	}
+	// The failure messages below name ExtractOutput so that a failure points at
+	// the method actually under test.
+	eo, ok := m["ExtractOutput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no registered entry found for ExtractOutput")
+	}
+	if got, want := eo.Call([]interface{}{2})[0].(int), 2; got != want {
+		t.Errorf("Wrapped ExtractOutput call: got %v, want %v", got, want)
+	}
+}
+
+// TestCombiner_PartialCombiner2 registers a combiner that defines no
+// ExtractOutput method and checks that the three methods it does define are
+// wrapped and callable.
+func TestCombiner_PartialCombiner2(t *testing.T) {
+	Combiner2[int, CustomType](&PartialCombiner2{})
+
+	wrapped, found := reflectx.WrapMethods(&PartialCombiner2{})
+	if !found {
+		t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no registered entry found")
+	}
+
+	createFn, found := wrapped["CreateAccumulator"]
+	if !found {
+		t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no registered entry found for CreateAccumulator")
+	}
+	noArgs := []interface{}{}
+	if got, want := createFn.Call(noArgs)[0].(int), 0; got != want {
+		t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", got, want)
+	}
+
+	addFn, found := wrapped["AddInput"]
+	if !found {
+		t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no registered entry found for AddInput")
+	}
+	addArgs := []interface{}{2, CustomType{val: 3}}
+	if got, want := addFn.Call(addArgs)[0].(int), 5; got != want {
+		t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+	}
+
+	mergeFn, found := wrapped["MergeAccumulators"]
+	if !found {
+		t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no registered entry found for MergeAccumulators")
+	}
+	mergeArgs := []interface{}{2, 4}
+	if got, want := mergeFn.Call(mergeArgs)[0].(int), 6; got != want {
+		t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", got, want)
+	}
+}
+
func TestEmitter1(t *testing.T) {
Emitter1[int]()
if !exec.IsEmitterRegistered(reflect.TypeOf((*func(int))(nil)).Elem()) {
@@ -438,3 +600,75 @@ type CustomTypeDoFn1x1 struct {
+// ProcessElement converts a CustomType into a CustomType2 by copying val into val2.
func (fn *CustomTypeDoFn1x1) ProcessElement(t CustomType) CustomType2 {
return CustomType2{val2: t.val}
}
+
+// CompleteCombiner3 is a test combiner that defines all four combiner methods
+// using three different types: an int accumulator, a CustomType input, and a
+// CustomType2 output.
+type CompleteCombiner3 struct{}
+
+// CreateAccumulator returns the zero accumulator.
+func (c *CompleteCombiner3) CreateAccumulator() int {
+	return 0
+}
+
+// AddInput returns the accumulator increased by the input's val.
+func (c *CompleteCombiner3) AddInput(acc int, in CustomType) int {
+	sum := acc + in.val
+	return sum
+}
+
+// MergeAccumulators returns the sum of the two accumulators.
+func (c *CompleteCombiner3) MergeAccumulators(left int, right int) int {
+	sum := left + right
+	return sum
+}
+
+// ExtractOutput wraps the accumulator in a CustomType2.
+func (c *CompleteCombiner3) ExtractOutput(acc int) CustomType2 {
+	out := CustomType2{val2: acc}
+	return out
+}
+
+// CompleteCombiner2 is a test combiner that defines all four combiner methods
+// using two different types: an int accumulator and a CustomType that serves
+// as both input and output.
+type CompleteCombiner2 struct{}
+
+// CreateAccumulator returns the zero accumulator.
+func (c *CompleteCombiner2) CreateAccumulator() int {
+	return 0
+}
+
+// AddInput returns the accumulator increased by the input's val.
+func (c *CompleteCombiner2) AddInput(acc int, in CustomType) int {
+	sum := acc + in.val
+	return sum
+}
+
+// MergeAccumulators returns the sum of the two accumulators.
+func (c *CompleteCombiner2) MergeAccumulators(left int, right int) int {
+	sum := left + right
+	return sum
+}
+
+// ExtractOutput wraps the accumulator in a CustomType.
+func (c *CompleteCombiner2) ExtractOutput(acc int) CustomType {
+	out := CustomType{val: acc}
+	return out
+}
+
+// CompleteCombiner1 is a test combiner that defines all four combiner methods
+// with int as the accumulator, input, and output type.
+type CompleteCombiner1 struct{}
+
+// CreateAccumulator returns the zero accumulator.
+func (c *CompleteCombiner1) CreateAccumulator() int {
+	return 0
+}
+
+// AddInput returns the accumulator increased by the input.
+func (c *CompleteCombiner1) AddInput(acc int, in int) int {
+	sum := acc + in
+	return sum
+}
+
+// MergeAccumulators returns the sum of the two accumulators.
+func (c *CompleteCombiner1) MergeAccumulators(left int, right int) int {
+	sum := left + right
+	return sum
+}
+
+// ExtractOutput returns the accumulator unchanged.
+func (c *CompleteCombiner1) ExtractOutput(acc int) int {
+	return acc
+}
+
+// PartialCombiner2 is a test combiner that defines CreateAccumulator, AddInput,
+// and MergeAccumulators but no ExtractOutput. It uses an int accumulator and a
+// CustomType input.
+type PartialCombiner2 struct{}
+
+// CreateAccumulator returns the zero accumulator.
+func (c *PartialCombiner2) CreateAccumulator() int {
+	return 0
+}
+
+// AddInput returns the accumulator increased by the input's val.
+func (c *PartialCombiner2) AddInput(acc int, in CustomType) int {
+	sum := acc + in.val
+	return sum
+}
+
+// MergeAccumulators returns the sum of the two accumulators.
+func (c *PartialCombiner2) MergeAccumulators(left int, right int) int {
+	sum := left + right
+	return sum
+}