You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") that is not preserved in this plain-text version.
Posted to commits@beam.apache.org by lo...@apache.org on 2022/05/10 00:05:31 UTC

[beam] branch master updated: [BEAM-14347] Add generic registration for accumulators (#17579)

This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a5fbc8ed163 [BEAM-14347] Add generic registration for accumulators (#17579)
a5fbc8ed163 is described below

commit a5fbc8ed163339879547a8e16e4f4f99a5b3a388
Author: Danny McCormick <da...@google.com>
AuthorDate: Mon May 9 20:05:24 2022 -0400

    [BEAM-14347] Add generic registration for accumulators (#17579)
---
 sdks/go/pkg/beam/registration/registration.go      | 655 ++++++++++++++++++++-
 sdks/go/pkg/beam/registration/registration.tmpl    | 310 +++++++++-
 sdks/go/pkg/beam/registration/registration_test.go | 234 ++++++++
 3 files changed, 1197 insertions(+), 2 deletions(-)

diff --git a/sdks/go/pkg/beam/registration/registration.go b/sdks/go/pkg/beam/registration/registration.go
index ba1585d71bf..95057a428ae 100644
--- a/sdks/go/pkg/beam/registration/registration.go
+++ b/sdks/go/pkg/beam/registration/registration.go
@@ -21,6 +21,7 @@ package registration
 
 import (
 	"context"
+	"fmt"
 	"reflect"
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
@@ -7030,13 +7031,665 @@ type teardown1x1 interface {
 	Teardown(ctx context.Context) error
 }
 
+// createAccumulator0x1 matches a combiner whose CreateAccumulator takes no
+// arguments and returns the accumulator of type T.
+type createAccumulator0x1[T any] interface {
+	CreateAccumulator() T
+}
+
+// createAccumulator0x2 matches a combiner whose CreateAccumulator takes no
+// arguments and returns the accumulator of type T plus an error.
+type createAccumulator0x2[T any] interface {
+	CreateAccumulator() (T, error)
+}
+
+// addInput2x1 matches a combiner whose AddInput takes an accumulator (T1) and
+// an input element (T2) and returns the updated accumulator (T1).
+type addInput2x1[T1, T2 any] interface {
+	AddInput(a T1, i T2) T1
+}
+
+// addInput2x2 is the error-returning variant of addInput2x1.
+type addInput2x2[T1, T2 any] interface {
+	AddInput(a T1, i T2) (T1, error)
+}
+
+// mergeAccumulators2x1 matches a combiner whose MergeAccumulators combines two
+// accumulators of type T into a single accumulator of type T.
+type mergeAccumulators2x1[T any] interface {
+	MergeAccumulators(a0 T, a1 T) T
+}
+
+// mergeAccumulators2x2 is the error-returning variant of mergeAccumulators2x1.
+type mergeAccumulators2x2[T any] interface {
+	MergeAccumulators(a0 T, a1 T) (T, error)
+}
+
+// extractOutput1x1 matches a combiner whose ExtractOutput converts an
+// accumulator (T1) into an output value (T2).
+type extractOutput1x1[T1, T2 any] interface {
+	ExtractOutput(a T1) T2
+}
+
+// extractOutput1x2 is the error-returning variant of extractOutput1x1.
+type extractOutput1x2[T1, T2 any] interface {
+	ExtractOutput(a T1) (T2, error)
+}
+
+// Combiner1 registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions, each of which should be used for a different situation.
+// Combiner1 should be used when your accumulator, input, and output are all of the same type.
+// It can be called with register.Combiner1[T](&CustomCombiner{})
+// where T is the type of the input/accumulator/output.
+//
+// accum is expected to be a pointer to the combiner struct: its type is
+// dereferenced with Elem() when the struct wrapper is registered.
+// MergeAccumulators is mandatory and must have the signature
+// func(T, T) T or func(T, T) (T, error). CreateAccumulator, AddInput, and
+// ExtractOutput are optional, but each one that is present must use T for
+// every argument and result (optionally followed by an error result);
+// otherwise Combiner1 panics.
+func Combiner1[T0 any](accum interface{}) {
+	registerCombinerTypes(accum)
+	accumVal := reflect.ValueOf(accum)
+	// For each method: register a typed caller for its func signature with
+	// reflectx, and keep a factory that binds the method of a given combiner
+	// instance as a reflectx.Func. The error-returning form is checked first.
+	var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(mergeAccumulators2x2[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) (T0, error))
+			return &caller2x2[T0, T0, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+		mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+				return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(mergeAccumulators2x1[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) T0)
+			return &caller2x1[T0, T0, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+		mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+				return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+			})
+		}
+	}
+
+	// MergeAccumulators is required, so failing to match it is always fatal.
+	if mergeAccumulatorsWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+	}
+
+	var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(createAccumulator0x2[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func() (T0, error))
+			return &caller0x2[T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+		createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func() (T0, error) {
+				return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+			})
+		}
+	} else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func() T0)
+			return &caller0x1[T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+		createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func() T0 {
+				return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+			})
+		}
+	}
+	// Optional methods only cause a panic when the method exists but its
+	// signature matched none of the supported forms.
+	if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+	}
+
+	var addInputWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(addInput2x2[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) (T0, error))
+			return &caller2x2[T0, T0, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+				return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) T0)
+			return &caller2x1[T0, T0, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+				return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+			})
+		}
+	}
+
+	if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+	}
+
+	var extractOutputWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(extractOutput1x2[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) (T0, error))
+			return &caller1x2[T0, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, error))(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+				return fn.(extractOutput1x2[T0, T0]).ExtractOutput(a0)
+			})
+		}
+	} else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) T0)
+			return &caller1x1[T0, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T0)(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) T0 {
+				return fn.(extractOutput1x1[T0, T0]).ExtractOutput(a0)
+			})
+		}
+	}
+
+	if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+	}
+
+	// wrapperFn is given a combiner instance and returns the bound function for
+	// each method matched above, keyed by method name; unmatched (absent)
+	// optional methods are left out of the map.
+	wrapperFn := func(fn interface{}) map[string]reflectx.Func {
+		m := map[string]reflectx.Func{}
+		if mergeAccumulatorsWrapper != nil {
+			m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+		}
+		if createAccumulatorWrapper != nil {
+			m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+		}
+		if addInputWrapper != nil {
+			m["AddInput"] = addInputWrapper(fn)
+		}
+		if extractOutputWrapper != nil {
+			m["ExtractOutput"] = extractOutputWrapper(fn)
+		}
+
+		return m
+	}
+	reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn)
+}
+
+// Combiner2 registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions, each of which should be used for a different situation.
+// Combiner2 should be used when your accumulator, input, and output are 2 distinct types.
+// It can be called with register.Combiner2[T1, T2](&CustomCombiner{})
+// where T1 is the type of the accumulator and T2 is the other type.
+//
+// In the signature below the type parameters are named T0 and T1: T0 (T1 in
+// the call example above) is the accumulator type and T1 (T2 above) is the
+// other type. accum is expected to be a pointer to the combiner struct, since
+// its type is dereferenced with Elem() when the struct wrapper is registered.
+// MergeAccumulators is mandatory and must operate on T0 only. AddInput may take
+// a T0 or a T1 input, and ExtractOutput may return a T0 or a T1 output; the
+// candidates are tried in that order, error-returning form first, and the
+// first match wins. Combiner2 panics if a present method cannot be matched.
+func Combiner2[T0, T1 any](accum interface{}) {
+	registerCombinerTypes(accum)
+	accumVal := reflect.ValueOf(accum)
+	// For each method: register a typed caller for its func signature with
+	// reflectx, and keep a factory that binds the method of a given combiner
+	// instance as a reflectx.Func.
+	var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(mergeAccumulators2x2[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) (T0, error))
+			return &caller2x2[T0, T0, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+		mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+				return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(mergeAccumulators2x1[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) T0)
+			return &caller2x1[T0, T0, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+		mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+				return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+			})
+		}
+	}
+
+	// MergeAccumulators is required, so failing to match it is always fatal.
+	if mergeAccumulatorsWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+	}
+
+	var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(createAccumulator0x2[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func() (T0, error))
+			return &caller0x2[T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+		createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func() (T0, error) {
+				return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+			})
+		}
+	} else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func() T0)
+			return &caller0x1[T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+		createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func() T0 {
+				return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+			})
+		}
+	}
+	// Optional methods only cause a panic when the method exists but its
+	// signature matched none of the supported forms.
+	if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+	}
+
+	// AddInput: try a T0 input first, then a T1 input.
+	var addInputWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(addInput2x2[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) (T0, error))
+			return &caller2x2[T0, T0, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+				return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) T0)
+			return &caller2x1[T0, T0, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+				return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(addInput2x2[T0, T1]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T1) (T0, error))
+			return &caller2x2[T0, T1, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) (T0, error))(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T1) (T0, error) {
+				return fn.(addInput2x2[T0, T1]).AddInput(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(addInput2x1[T0, T1]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T1) T0)
+			return &caller2x1[T0, T1, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) T0)(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T1) T0 {
+				return fn.(addInput2x1[T0, T1]).AddInput(a0, a1)
+			})
+		}
+	}
+
+	if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+	}
+
+	// ExtractOutput: try a T0 output first, then a T1 output.
+	var extractOutputWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(extractOutput1x2[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) (T0, error))
+			return &caller1x2[T0, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, error))(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+				return fn.(extractOutput1x2[T0, T0]).ExtractOutput(a0)
+			})
+		}
+	} else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) T0)
+			return &caller1x1[T0, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T0)(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) T0 {
+				return fn.(extractOutput1x1[T0, T0]).ExtractOutput(a0)
+			})
+		}
+	} else if _, ok := accum.(extractOutput1x2[T0, T1]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) (T1, error))
+			return &caller1x2[T0, T1, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T1, error))(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) (T1, error) {
+				return fn.(extractOutput1x2[T0, T1]).ExtractOutput(a0)
+			})
+		}
+	} else if _, ok := accum.(extractOutput1x1[T0, T1]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) T1)
+			return &caller1x1[T0, T1]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T1)(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) T1 {
+				return fn.(extractOutput1x1[T0, T1]).ExtractOutput(a0)
+			})
+		}
+	}
+
+	if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+	}
+
+	// wrapperFn is given a combiner instance and returns the bound function for
+	// each method matched above, keyed by method name; unmatched (absent)
+	// optional methods are left out of the map.
+	wrapperFn := func(fn interface{}) map[string]reflectx.Func {
+		m := map[string]reflectx.Func{}
+		if mergeAccumulatorsWrapper != nil {
+			m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+		}
+		if createAccumulatorWrapper != nil {
+			m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+		}
+		if addInputWrapper != nil {
+			m["AddInput"] = addInputWrapper(fn)
+		}
+		if extractOutputWrapper != nil {
+			m["ExtractOutput"] = extractOutputWrapper(fn)
+		}
+
+		return m
+	}
+	reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn)
+}
+
+// Combiner3 registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions, each of which should be used for a different situation.
+// Combiner3 should be used when your accumulator, input, and output are 3 distinct types.
+// It can be called with register.Combiner3[T1, T2, T3](&CustomCombiner{})
+// where T1 is the type of the accumulator, T2 is the type of the input, and T3 is the type of the output.
+//
+// In the signature below the type parameters are named T0, T1, and T2 (T1, T2,
+// and T3 in the call example above); T0 is always the accumulator type. accum
+// is expected to be a pointer to the combiner struct, since its type is
+// dereferenced with Elem() when the struct wrapper is registered.
+// MergeAccumulators is mandatory and must operate on T0 only. AddInput may take
+// a T0, T1, or T2 input and ExtractOutput may return a T0, T1, or T2 output;
+// the candidates are tried in that order, error-returning form first, and the
+// first match wins. Combiner3 panics if a present method cannot be matched.
+func Combiner3[T0, T1, T2 any](accum interface{}) {
+	registerCombinerTypes(accum)
+	accumVal := reflect.ValueOf(accum)
+	// For each method: register a typed caller for its func signature with
+	// reflectx, and keep a factory that binds the method of a given combiner
+	// instance as a reflectx.Func.
+	var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(mergeAccumulators2x2[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) (T0, error))
+			return &caller2x2[T0, T0, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+		mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+				return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(mergeAccumulators2x1[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) T0)
+			return &caller2x1[T0, T0, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+		mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+				return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+			})
+		}
+	}
+
+	// MergeAccumulators is required, so failing to match it is always fatal.
+	if mergeAccumulatorsWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+	}
+
+	var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(createAccumulator0x2[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func() (T0, error))
+			return &caller0x2[T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+		createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func() (T0, error) {
+				return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+			})
+		}
+	} else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func() T0)
+			return &caller0x1[T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+		createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func() T0 {
+				return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+			})
+		}
+	}
+	// Optional methods only cause a panic when the method exists but its
+	// signature matched none of the supported forms.
+	if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+	}
+
+	// AddInput: try a T0 input, then T1, then T2.
+	var addInputWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(addInput2x2[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) (T0, error))
+			return &caller2x2[T0, T0, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+				return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T0) T0)
+			return &caller2x1[T0, T0, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+				return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(addInput2x2[T0, T1]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T1) (T0, error))
+			return &caller2x2[T0, T1, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) (T0, error))(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T1) (T0, error) {
+				return fn.(addInput2x2[T0, T1]).AddInput(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(addInput2x1[T0, T1]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T1) T0)
+			return &caller2x1[T0, T1, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) T0)(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T1) T0 {
+				return fn.(addInput2x1[T0, T1]).AddInput(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(addInput2x2[T0, T2]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T2) (T0, error))
+			return &caller2x2[T0, T2, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T2) (T0, error))(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T2) (T0, error) {
+				return fn.(addInput2x2[T0, T2]).AddInput(a0, a1)
+			})
+		}
+	} else if _, ok := accum.(addInput2x1[T0, T2]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0, T2) T0)
+			return &caller2x1[T0, T2, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T2) T0)(nil)).Elem(), caller)
+
+		addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T2) T0 {
+				return fn.(addInput2x1[T0, T2]).AddInput(a0, a1)
+			})
+		}
+	}
+
+	if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+	}
+
+	// ExtractOutput: try a T0 output, then T1, then T2.
+	var extractOutputWrapper func(fn interface{}) reflectx.Func
+	if _, ok := accum.(extractOutput1x2[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) (T0, error))
+			return &caller1x2[T0, T0, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, error))(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+				return fn.(extractOutput1x2[T0, T0]).ExtractOutput(a0)
+			})
+		}
+	} else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) T0)
+			return &caller1x1[T0, T0]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T0)(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) T0 {
+				return fn.(extractOutput1x1[T0, T0]).ExtractOutput(a0)
+			})
+		}
+	} else if _, ok := accum.(extractOutput1x2[T0, T1]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) (T1, error))
+			return &caller1x2[T0, T1, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T1, error))(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) (T1, error) {
+				return fn.(extractOutput1x2[T0, T1]).ExtractOutput(a0)
+			})
+		}
+	} else if _, ok := accum.(extractOutput1x1[T0, T1]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) T1)
+			return &caller1x1[T0, T1]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T1)(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) T1 {
+				return fn.(extractOutput1x1[T0, T1]).ExtractOutput(a0)
+			})
+		}
+	} else if _, ok := accum.(extractOutput1x2[T0, T2]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) (T2, error))
+			return &caller1x2[T0, T2, error]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T2, error))(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) (T2, error) {
+				return fn.(extractOutput1x2[T0, T2]).ExtractOutput(a0)
+			})
+		}
+	} else if _, ok := accum.(extractOutput1x1[T0, T2]); ok {
+		caller := func(fn interface{}) reflectx.Func {
+			f := fn.(func(T0) T2)
+			return &caller1x1[T0, T2]{fn: f}
+		}
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T2)(nil)).Elem(), caller)
+
+		extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) T2 {
+				return fn.(extractOutput1x1[T0, T2]).ExtractOutput(a0)
+			})
+		}
+	}
+
+	if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil {
+		panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+	}
+
+	// wrapperFn is given a combiner instance and returns the bound function for
+	// each method matched above, keyed by method name; unmatched (absent)
+	// optional methods are left out of the map.
+	wrapperFn := func(fn interface{}) map[string]reflectx.Func {
+		m := map[string]reflectx.Func{}
+		if mergeAccumulatorsWrapper != nil {
+			m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+		}
+		if createAccumulatorWrapper != nil {
+			m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+		}
+		if addInputWrapper != nil {
+			m["AddInput"] = addInputWrapper(fn)
+		}
+		if extractOutputWrapper != nil {
+			m["ExtractOutput"] = extractOutputWrapper(fn)
+		}
+
+		return m
+	}
+	reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn)
+}
+
+// registerCombinerTypes registers the combiner struct type (accum is expected
+// to be a pointer to it, since the type is dereferenced with Elem()) with the
+// runtime and schema type registries. It then registers the types that appear
+// in the signatures of the combiner's MergeAccumulators, AddInput, and
+// ExtractOutput methods.
+//
+// MergeAccumulators is required for every combiner. Its absence is reported
+// with a descriptive panic. Without this check, calling Type() on the zero
+// reflect.Value returned by MethodByName would produce an opaque reflect panic.
+// AddInput and ExtractOutput are optional and are only inspected when present.
+func registerCombinerTypes(accum interface{}) {
+	accumVal := reflect.ValueOf(accum)
+	mergeAccumulators := accumVal.MethodByName("MergeAccumulators")
+	if !mergeAccumulators.IsValid() {
+		panic(fmt.Sprintf("Failed to register combiner %v. Combiners must define a MergeAccumulators method", accum))
+	}
+
+	// Register the combiner
+	runtime.RegisterType(reflect.TypeOf(accum).Elem())
+	schema.RegisterType(reflect.TypeOf(accum).Elem())
+
+	// Register all types in the Combiner.
+	// There may be different types across MergeAccumulators, AddInput, and ExtractOutput.
+	registerMethodTypes(mergeAccumulators.Type())
+	for _, name := range []string{"AddInput", "ExtractOutput"} {
+		if m := accumVal.MethodByName(name); m.IsValid() {
+			registerMethodTypes(m.Type())
+		}
+	}
+}
+
+// registerDoFnTypes registers the DoFn's struct type with the runtime and
+// schema type registries. doFn is expected to be a pointer, since its type is
+// dereferenced with Elem(). It then passes the ProcessElement signature to
+// registerMethodTypes.
+//
+// NOTE(review): MethodByName returns a zero reflect.Value when ProcessElement
+// is missing, and calling Type() on it panics inside reflect. Confirm that
+// callers validate the DoFn before this point.
 func registerDoFnTypes(doFn interface{}) {
 	// Register the doFn
 	runtime.RegisterType(reflect.TypeOf(doFn).Elem())
 	schema.RegisterType(reflect.TypeOf(doFn).Elem())
 
 	// Register all types in the DoFn
-	fn := reflect.ValueOf(doFn).MethodByName("ProcessElement").Type()
+	registerMethodTypes(reflect.ValueOf(doFn).MethodByName("ProcessElement").Type())
+}
+
+func registerMethodTypes(fn reflect.Type) {
 	for i := 0; i < fn.NumIn(); i++ {
 		in := reflectx.SkipPtr(fn.In(i))
 		if in.Kind() == reflect.Struct {
diff --git a/sdks/go/pkg/beam/registration/registration.tmpl b/sdks/go/pkg/beam/registration/registration.tmpl
index e046a9d165f..f73c0e323e2 100644
--- a/sdks/go/pkg/beam/registration/registration.tmpl
+++ b/sdks/go/pkg/beam/registration/registration.tmpl
@@ -118,6 +118,7 @@ package registration
 
 import (
 	"context"
+    "fmt"
 	"reflect"
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
@@ -241,13 +242,320 @@ type teardown1x1 interface {
 	Teardown(ctx context.Context) error
 }
 
+// createAccumulator0x1 matches combiners whose CreateAccumulator takes no
+// arguments and returns only the accumulator.
+type createAccumulator0x1[T any] interface {
+    CreateAccumulator() T
+}
+
+// createAccumulator0x2 matches combiners whose CreateAccumulator takes no
+// arguments and returns the accumulator plus an error.
+type createAccumulator0x2[T any] interface {
+    CreateAccumulator() (T, error)
+}
+
+// addInput2x1 matches combiners whose AddInput folds an input of type T2
+// into an accumulator of type T1 and returns only the new accumulator.
+type addInput2x1[T1, T2 any] interface {
+    AddInput(a T1, i T2) T1
+}
+
+// addInput2x2 is the error-returning form of addInput2x1.
+type addInput2x2[T1, T2 any] interface {
+    AddInput(a T1, i T2) (T1, error)
+}
+
+// mergeAccumulators2x1 matches combiners whose MergeAccumulators combines two
+// accumulators of type T and returns only the merged accumulator.
+type mergeAccumulators2x1[T any] interface {
+    MergeAccumulators(a0 T, a1 T) T
+}
+
+// mergeAccumulators2x2 is the error-returning form of mergeAccumulators2x1.
+type mergeAccumulators2x2[T any] interface {
+    MergeAccumulators(a0 T, a1 T) (T, error)
+}
+
+// extractOutput1x1 matches combiners whose ExtractOutput turns an
+// accumulator of type T1 into an output of type T2.
+type extractOutput1x1[T1, T2 any] interface {
+    ExtractOutput(a T1) T2
+}
+
+// extractOutput1x2 is the error-returning form of extractOutput1x1.
+type extractOutput1x2[T1, T2 any] interface {
+    ExtractOutput(a T1) (T2, error)
+}
+
+{{range $accum := upto 3}}{{$genericParams := (add $accum 1)}}
+// Combiner{{$genericParams}} registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions, each of which should be used for a different situation.
+{{if (eq $genericParams 1)}}// Combiner1 should be used when your accumulator, input, and output are all of the same type.
+// It can be called with register.Combiner1[T](&CustomCombiner{})
+// where T is the type of the input/accumulator/output.
+{{else}}{{if (eq $genericParams 2)}}// Combiner2 should be used when your accumulator, input, and output are 2 distinct types.
+// It can be called with register.Combiner2[T1, T2](&CustomCombiner{})
+// where T1 is the type of the accumulator and T2 is the other type.
+{{else}}// Combiner3 should be used when your accumulator, input, and output are 3 distinct types.
+// It can be called with register.Combiner3[T1, T2, T3](&CustomCombiner{})
+// where T1 is the type of the accumulator, T2 is the type of the input, and T3 is the type of the output.
+{{end}}{{end}}func Combiner{{$genericParams}}[{{range $paramNum := upto $genericParams}}{{if $paramNum}}, {{end}}T{{$paramNum}}{{end}} any](accum interface{}) {
+    registerCombinerTypes(accum)
+    accumVal := reflect.ValueOf(accum)
+    // MergeAccumulators is required and must operate on the accumulator type
+    // T0. For each method, the (T, error)-returning form is tried before the
+    // plain form, and a caller for the matched func signature is registered.
+    var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+    if _, ok := accum.(mergeAccumulators2x2[T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T0) (T0, error))
+            return &caller2x2[T0, T0, T0, error]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+        mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+				return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+			})
+		}
+    } else if _, ok := accum.(mergeAccumulators2x1[T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T0) T0)
+            return &caller2x1[T0, T0, T0]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+        mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+				return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+			})
+		}
+    }
+
+    if mergeAccumulatorsWrapper == nil {
+        panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+    }
+
+    // CreateAccumulator is optional, but if present it must return T0
+    // (optionally with an error); otherwise registration panics below.
+    var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+    if _, ok := accum.(createAccumulator0x2[T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func() (T0, error))
+            return &caller0x2[T0, error]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+        createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func() (T0, error) {
+				return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+			})
+		}
+    } else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func() T0)
+            return &caller0x1[T0]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+        createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func() T0 {
+				return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+			})
+		}
+    }
+    if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil {
+        panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+    }
+
+    // AddInput is optional. Its accumulator argument and result are T0; its
+    // input may be T0 or, in the multi-parameter variants, another type parameter.
+    var addInputWrapper func(fn interface{}) reflectx.Func
+    if _, ok := accum.(addInput2x2[T0, T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T0) (T0, error))
+            return &caller2x2[T0, T0, T0, error]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+				return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+			})
+		}
+    } else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T0) T0)
+            return &caller2x1[T0, T0, T0]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+				return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+			})
+		}
+    } {{if (gt $genericParams 1)}} else if _, ok := accum.(addInput2x2[T0, T1]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T1) (T0, error))
+            return &caller2x2[T0, T1, T0, error]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) (T0, error))(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T1) (T0, error) {
+				return fn.(addInput2x2[T0, T1]).AddInput(a0, a1)
+			})
+		}
+    } else if _, ok := accum.(addInput2x1[T0, T1]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T1) T0)
+            return &caller2x1[T0, T1, T0]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) T0)(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T1) T0 {
+				return fn.(addInput2x1[T0, T1]).AddInput(a0, a1)
+			})
+		}
+    } {{end}}{{if (gt $genericParams 2)}} else if _, ok := accum.(addInput2x2[T0, T2]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T2) (T0, error))
+            return &caller2x2[T0, T2, T0, error]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T2) (T0, error))(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T2) (T0, error) {
+				return fn.(addInput2x2[T0, T2]).AddInput(a0, a1)
+			})
+		}
+    } else if _, ok := accum.(addInput2x1[T0, T2]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T2) T0)
+            return &caller2x1[T0, T2, T0]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T2) T0)(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0, a1 T2) T0 {
+				return fn.(addInput2x1[T0, T2]).AddInput(a0, a1)
+			})
+		}
+    } {{end}}
+
+    if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil {
+        panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+    }
+
+    // ExtractOutput is optional. It consumes T0 and may produce T0 or, in the
+    // multi-parameter variants, another type parameter.
+    var extractOutputWrapper func(fn interface{}) reflectx.Func
+    if _, ok := accum.(extractOutput1x2[T0, T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) (T0, error))
+            return &caller1x2[T0, T0, error]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, error))(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+				return fn.(extractOutput1x2[T0, T0]).ExtractOutput(a0)
+			})
+		}
+    } else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) T0)
+            return &caller1x1[T0, T0]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T0)(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) T0 {
+				return fn.(extractOutput1x1[T0, T0]).ExtractOutput(a0)
+			})
+		}
+    } {{if (gt $genericParams 1)}} else if _, ok := accum.(extractOutput1x2[T0, T1]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) (T1, error))
+            return &caller1x2[T0, T1, error]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T1, error))(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) (T1, error) {
+				return fn.(extractOutput1x2[T0, T1]).ExtractOutput(a0)
+			})
+		}
+    } else if _, ok := accum.(extractOutput1x1[T0, T1]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) T1)
+            return &caller1x1[T0, T1]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T1)(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) T1 {
+				return fn.(extractOutput1x1[T0, T1]).ExtractOutput(a0)
+			})
+		}
+    } {{end}}{{if (gt $genericParams 2)}} else if _, ok := accum.(extractOutput1x2[T0, T2]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) (T2, error))
+            return &caller1x2[T0, T2, error]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T2, error))(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) (T2, error) {
+				return fn.(extractOutput1x2[T0, T2]).ExtractOutput(a0)
+			})
+		}
+    } else if _, ok := accum.(extractOutput1x1[T0, T2]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) T2)
+            return &caller1x1[T0, T2]{fn: f}
+        }
+		reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T2)(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+			return reflectx.MakeFunc(func(a0 T0) T2 {
+				return fn.(extractOutput1x1[T0, T2]).ExtractOutput(a0)
+			})
+		}
+    } {{end}}
+
+    if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil {
+        panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+    }
+
+	// wrapperFn builds the method table for a given combiner instance. It
+	// includes only the methods whose wrappers were resolved above.
+	wrapperFn := func(fn interface{}) map[string]reflectx.Func {
+		m := map[string]reflectx.Func{}
+		if mergeAccumulatorsWrapper != nil {
+			m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+		}
+		if createAccumulatorWrapper != nil {
+			m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+		}
+		if addInputWrapper != nil {
+			m["AddInput"] = addInputWrapper(fn)
+		}
+		if extractOutputWrapper != nil {
+			m["ExtractOutput"] = extractOutputWrapper(fn)
+		}
+
+		return m
+	}
+	reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn)
+}{{end}}
+
+// registerCombinerTypes registers the combiner's own struct type with the
+// runtime and schema registries, then hands the signatures of its
+// MergeAccumulators, AddInput, and ExtractOutput methods to registerMethodTypes.
+//
+// Every method is looked up defensively. reflect.Value.MethodByName returns
+// the zero Value for a missing method, and calling Type on that panics. A
+// combiner lacking MergeAccumulators is therefore left for the CombinerN
+// functions to reject with their descriptive panic, instead of crashing here.
+func registerCombinerTypes(accum interface{}) {
+    // Register the combiner
+    runtime.RegisterType(reflect.TypeOf(accum).Elem())
+    schema.RegisterType(reflect.TypeOf(accum).Elem())
+
+    // Register all types in the Combiner.
+    // There may be different types across MergeAccumulators, AddInput, and ExtractOutput.
+    accumVal := reflect.ValueOf(accum)
+    for _, name := range []string{"MergeAccumulators", "AddInput", "ExtractOutput"} {
+        if m := accumVal.MethodByName(name); m.IsValid() {
+            registerMethodTypes(m.Type())
+        }
+    }
+}
+
+// registerDoFnTypes registers the DoFn's struct type with the runtime and
+// schema registries, then passes the signature of its ProcessElement method
+// to registerMethodTypes.
 func registerDoFnTypes(doFn interface{}) {
    // Register the doFn
    runtime.RegisterType(reflect.TypeOf(doFn).Elem())
    schema.RegisterType(reflect.TypeOf(doFn).Elem())
   
    // Register all types in the DoFn
-   fn := reflect.ValueOf(doFn).MethodByName("ProcessElement").Type()
+   registerMethodTypes(reflect.ValueOf(doFn).MethodByName("ProcessElement").Type())
+}
+
+func registerMethodTypes(fn reflect.Type) {
    for i := 0; i < fn.NumIn(); i++ {
        in := reflectx.SkipPtr(fn.In(i))
        if in.Kind() == reflect.Struct {
diff --git a/sdks/go/pkg/beam/registration/registration_test.go b/sdks/go/pkg/beam/registration/registration_test.go
index 4f5c37dd2a8..2b535692ff2 100644
--- a/sdks/go/pkg/beam/registration/registration_test.go
+++ b/sdks/go/pkg/beam/registration/registration_test.go
@@ -159,6 +159,168 @@ func TestRegister_RegistersTypes(t *testing.T) {
 	}
 }
 
+// TestCombiner_CompleteCombiner3 checks that Combiner3 registers a wrapper for
+// every method of a combiner whose accumulator, input, and output types all
+// differ, and that each wrapper forwards to the underlying method.
+func TestCombiner_CompleteCombiner3(t *testing.T) {
+	accum := &CompleteCombiner3{}
+	Combiner3[int, CustomType, CustomType2](accum)
+
+	m, ok := reflectx.WrapMethods(&CompleteCombiner3{})
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no registered entry found")
+	}
+	ca, ok := m["CreateAccumulator"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no registered entry found for CreateAccumulator")
+	}
+	if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+		t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", got, want)
+	}
+	ai, ok := m["AddInput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no registered entry found for AddInput")
+	}
+	if got, want := ai.Call([]interface{}{2, CustomType{val: 3}})[0].(int), 5; got != want {
+		t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+	}
+	ma, ok := m["MergeAccumulators"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no registered entry found for MergeAccumulators")
+	}
+	if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+		t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", got, want)
+	}
+	eo, ok := m["ExtractOutput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no registered entry found for ExtractOutput")
+	}
+	if got, want := eo.Call([]interface{}{2})[0].(CustomType2).val2, 2; got != want {
+		t.Errorf("Wrapped ExtractOutput call: got %v, want %v", got, want)
+	}
+}
+
+// TestCombiner_RegistersTypes checks that Combiner3 registers the combiner
+// struct and the distinct input and output types of its methods with the
+// schema registry.
+func TestCombiner_RegistersTypes(t *testing.T) {
+	accum := &CompleteCombiner3{}
+	Combiner3[int, CustomType, CustomType2](accum)
+
+	// Need to call FromType so that the registry will reconcile its registrations
+	schema.FromType(reflect.TypeOf(accum).Elem())
+	if !schema.Registered(reflect.TypeOf(accum).Elem()) {
+		t.Errorf("schema.Registered(reflect.TypeOf(CompleteCombiner3{})) = false, want true")
+	}
+	if !schema.Registered(reflect.TypeOf(CustomType{})) {
+		t.Errorf("schema.Registered(reflect.TypeOf(CustomType{})) = false, want true")
+	}
+	if !schema.Registered(reflect.TypeOf(CustomType2{})) {
+		t.Errorf("schema.Registered(reflect.TypeOf(CustomType2{})) = false, want true")
+	}
+}
+
+// TestCombiner_CompleteCombiner2 checks that Combiner2 registers a wrapper for
+// every method of a combiner with an int accumulator and a shared
+// input/output type, and that each wrapper forwards to the underlying method.
+func TestCombiner_CompleteCombiner2(t *testing.T) {
+	accum := &CompleteCombiner2{}
+	Combiner2[int, CustomType](accum)
+
+	m, ok := reflectx.WrapMethods(&CompleteCombiner2{})
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no registered entry found")
+	}
+	ca, ok := m["CreateAccumulator"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no registered entry found for CreateAccumulator")
+	}
+	if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+		t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", got, want)
+	}
+	ai, ok := m["AddInput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no registered entry found for AddInput")
+	}
+	if got, want := ai.Call([]interface{}{2, CustomType{val: 3}})[0].(int), 5; got != want {
+		t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+	}
+	ma, ok := m["MergeAccumulators"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no registered entry found for MergeAccumulators")
+	}
+	if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+		t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", got, want)
+	}
+	eo, ok := m["ExtractOutput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no registered entry found for ExtractOutput")
+	}
+	if got, want := eo.Call([]interface{}{2})[0].(CustomType).val, 2; got != want {
+		t.Errorf("Wrapped ExtractOutput call: got %v, want %v", got, want)
+	}
+}
+
+// TestCombiner_CompleteCombiner1 checks that Combiner1 registers a wrapper for
+// every method of a combiner whose accumulator, input, and output are all int,
+// and that each wrapper forwards to the underlying method.
+func TestCombiner_CompleteCombiner1(t *testing.T) {
+	accum := &CompleteCombiner1{}
+	Combiner1[int](accum)
+
+	m, ok := reflectx.WrapMethods(&CompleteCombiner1{})
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no registered entry found")
+	}
+	ca, ok := m["CreateAccumulator"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no registered entry found for CreateAccumulator")
+	}
+	if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+		t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", got, want)
+	}
+	ai, ok := m["AddInput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no registered entry found for AddInput")
+	}
+	if got, want := ai.Call([]interface{}{2, 3})[0].(int), 5; got != want {
+		t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+	}
+	ma, ok := m["MergeAccumulators"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no registered entry found for MergeAccumulators")
+	}
+	if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+		t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", got, want)
+	}
+	eo, ok := m["ExtractOutput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no registered entry found for ExtractOutput")
+	}
+	if got, want := eo.Call([]interface{}{2})[0].(int), 2; got != want {
+		t.Errorf("Wrapped ExtractOutput call: got %v, want %v", got, want)
+	}
+}
+
+// TestCombiner_PartialCombiner2 checks that Combiner2 accepts a combiner that
+// omits ExtractOutput. It wraps the methods that do exist and registers no
+// wrapper for the missing one.
+func TestCombiner_PartialCombiner2(t *testing.T) {
+	accum := &PartialCombiner2{}
+	Combiner2[int, CustomType](accum)
+
+	m, ok := reflectx.WrapMethods(&PartialCombiner2{})
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no registered entry found")
+	}
+	ca, ok := m["CreateAccumulator"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no registered entry found for CreateAccumulator")
+	}
+	if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+		t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", got, want)
+	}
+	ai, ok := m["AddInput"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no registered entry found for AddInput")
+	}
+	if got, want := ai.Call([]interface{}{2, CustomType{val: 3}})[0].(int), 5; got != want {
+		t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+	}
+	ma, ok := m["MergeAccumulators"]
+	if !ok {
+		t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no registered entry found for MergeAccumulators")
+	}
+	if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+		t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", got, want)
+	}
+	// PartialCombiner2 defines no ExtractOutput, so no wrapper may be registered for it.
+	if _, ok := m["ExtractOutput"]; ok {
+		t.Errorf("reflectx.WrapMethods(&PartialCombiner2{}) registered an entry for ExtractOutput, want none")
+	}
+}
+
 func TestEmitter1(t *testing.T) {
 	Emitter1[int]()
 	if !exec.IsEmitterRegistered(reflect.TypeOf((*func(int))(nil)).Elem()) {
@@ -438,3 +600,75 @@ type CustomTypeDoFn1x1 struct {
+// ProcessElement converts a CustomType into a CustomType2 carrying the same value.
 func (fn *CustomTypeDoFn1x1) ProcessElement(t CustomType) CustomType2 {
 	return CustomType2{val2: t.val}
 }
+
+// CompleteCombiner3 is a test combiner with three distinct types: an int
+// accumulator, CustomType inputs, and a CustomType2 output.
+type CompleteCombiner3 struct{}
+
+func (c *CompleteCombiner3) CreateAccumulator() int {
+	return 0
+}
+
+func (c *CompleteCombiner3) AddInput(acc int, in CustomType) int {
+	return acc + in.val
+}
+
+func (c *CompleteCombiner3) MergeAccumulators(a int, b int) int {
+	return a + b
+}
+
+func (c *CompleteCombiner3) ExtractOutput(acc int) CustomType2 {
+	return CustomType2{val2: acc}
+}
+
+// CompleteCombiner2 is a test combiner with an int accumulator and
+// CustomType used for both input and output.
+type CompleteCombiner2 struct{}
+
+func (c *CompleteCombiner2) CreateAccumulator() int {
+	return 0
+}
+
+func (c *CompleteCombiner2) AddInput(acc int, in CustomType) int {
+	return acc + in.val
+}
+
+func (c *CompleteCombiner2) MergeAccumulators(a int, b int) int {
+	return a + b
+}
+
+func (c *CompleteCombiner2) ExtractOutput(acc int) CustomType {
+	return CustomType{val: acc}
+}
+
+// CompleteCombiner1 is a test combiner whose accumulator, input, and output
+// are all int.
+type CompleteCombiner1 struct{}
+
+func (c *CompleteCombiner1) CreateAccumulator() int {
+	return 0
+}
+
+func (c *CompleteCombiner1) AddInput(acc int, in int) int {
+	return acc + in
+}
+
+func (c *CompleteCombiner1) MergeAccumulators(a int, b int) int {
+	return a + b
+}
+
+func (c *CompleteCombiner1) ExtractOutput(acc int) int {
+	return acc
+}
+
+// PartialCombiner2 is a test combiner that deliberately omits ExtractOutput,
+// so registration must cope with a missing optional method.
+type PartialCombiner2 struct{}
+
+func (c *PartialCombiner2) CreateAccumulator() int {
+	return 0
+}
+
+func (c *PartialCombiner2) AddInput(acc int, in CustomType) int {
+	return acc + in.val
+}
+
+func (c *PartialCombiner2) MergeAccumulators(a int, b int) int {
+	return a + b
+}