This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a5fbc8ed163 [BEAM-14347] Add generic registration for accumulators (#17579)
a5fbc8ed163 is described below

commit a5fbc8ed163339879547a8e16e4f4f99a5b3a388
Author: Danny McCormick <[email protected]>
AuthorDate: Mon May 9 20:05:24 2022 -0400

    [BEAM-14347] Add generic registration for accumulators (#17579)
---
 sdks/go/pkg/beam/registration/registration.go      | 655 ++++++++++++++++++++-
 sdks/go/pkg/beam/registration/registration.tmpl    | 310 +++++++++-
 sdks/go/pkg/beam/registration/registration_test.go | 234 ++++++++
 3 files changed, 1197 insertions(+), 2 deletions(-)

diff --git a/sdks/go/pkg/beam/registration/registration.go b/sdks/go/pkg/beam/registration/registration.go
index ba1585d71bf..95057a428ae 100644
--- a/sdks/go/pkg/beam/registration/registration.go
+++ b/sdks/go/pkg/beam/registration/registration.go
@@ -21,6 +21,7 @@ package registration
 
 import (
        "context"
+       "fmt"
        "reflect"
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
@@ -7030,13 +7031,665 @@ type teardown1x1 interface {
        Teardown(ctx context.Context) error
 }
 
+// The interfaces below describe the signatures a CombineFn's methods may have.
+// The NxM suffix gives the parameter and result counts; the Combiner functions
+// type-assert a CombineFn against them to select a matching optimized caller.
+
+// createAccumulator0x1 matches a CreateAccumulator that takes no arguments and
+// returns a T.
+type createAccumulator0x1[T any] interface {
+       CreateAccumulator() T
+}
+
+// createAccumulator0x2 matches a CreateAccumulator that returns a T and an error.
+type createAccumulator0x2[T any] interface {
+       CreateAccumulator() (T, error)
+}
+
+// addInput2x1 matches an AddInput that takes an accumulator of type T1 and an
+// input of type T2 and returns a T1.
+type addInput2x1[T1, T2 any] interface {
+       AddInput(a T1, i T2) T1
+}
+
+// addInput2x2 matches an AddInput that takes an accumulator of type T1 and an
+// input of type T2 and returns a T1 and an error.
+type addInput2x2[T1, T2 any] interface {
+       AddInput(a T1, i T2) (T1, error)
+}
+
+// mergeAccumulators2x1 matches a MergeAccumulators that takes two values of
+// type T and returns a T.
+type mergeAccumulators2x1[T any] interface {
+       MergeAccumulators(a0 T, a1 T) T
+}
+
+// mergeAccumulators2x2 matches a MergeAccumulators that takes two values of
+// type T and returns a T and an error.
+type mergeAccumulators2x2[T any] interface {
+       MergeAccumulators(a0 T, a1 T) (T, error)
+}
+
+// extractOutput1x1 matches an ExtractOutput that takes an accumulator of type
+// T1 and returns a T2.
+type extractOutput1x1[T1, T2 any] interface {
+       ExtractOutput(a T1) T2
+}
+
+// extractOutput1x2 matches an ExtractOutput that takes an accumulator of type
+// T1 and returns a T2 and an error.
+type extractOutput1x2[T1, T2 any] interface {
+       ExtractOutput(a T1) (T2, error)
+}
+
+// Combiner1 registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions, each of which should be used for a different situation.
+// Combiner1 should be used when your accumulator, input, and output are all of the same type.
+// It can be called with registration.Combiner1[T](&CustomCombiner{})
+// where T is the type of the input/accumulator/output.
+//
+// accum is expected to be a pointer to the CombineFn struct, since the type that
+// gets registered is reflect.TypeOf(accum).Elem(). MergeAccumulators must have
+// the signature func(T, T) T or func(T, T) (T, error), otherwise Combiner1
+// panics. CreateAccumulator, AddInput, and ExtractOutput are optional, but when
+// present they must likewise use only T (optionally with a trailing error
+// result), otherwise Combiner1 panics.
+func Combiner1[T0 any](accum interface{}) {
+       registerCombinerTypes(accum)
+       accumVal := reflect.ValueOf(accum)
+       // For each method, probe the signature with a trailing error result first,
+       // then the plain one. A match registers a reflectx caller factory for the bare
+       // func type and records a wrapper that binds the method of a concrete
+       // CombineFn value via an interface assertion at call time.
+       var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(mergeAccumulators2x2[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) (T0, error))
+                       return &caller2x2[T0, T0, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+               mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+                               return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(mergeAccumulators2x1[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) T0)
+                       return &caller2x1[T0, T0, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+               mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+                               return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+                       })
+               }
+       }
+
+       // MergeAccumulators is mandatory, so an unmatched signature is always fatal.
+       if mergeAccumulatorsWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+       }
+
+       var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(createAccumulator0x2[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func() (T0, error))
+                       return &caller0x2[T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+               createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func() (T0, error) {
+                               return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+                       })
+               }
+       } else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func() T0)
+                       return &caller0x1[T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+               createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func() T0 {
+                               return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+                       })
+               }
+       }
+       // Optional methods only fail registration when they exist but matched none
+       // of the probed signatures.
+       if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+       }
+
+       var addInputWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(addInput2x2[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) (T0, error))
+                       return &caller2x2[T0, T0, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+                               return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) T0)
+                       return &caller2x1[T0, T0, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+                               return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+                       })
+               }
+       }
+
+       if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+       }
+
+       var extractOutputWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(extractOutput1x2[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) (T0, error))
+                       return &caller1x2[T0, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, error))(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+                               return fn.(extractOutput1x2[T0, T0]).ExtractOutput(a0)
+                       })
+               }
+       } else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) T0)
+                       return &caller1x1[T0, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T0)(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) T0 {
+                               return fn.(extractOutput1x1[T0, T0]).ExtractOutput(a0)
+                       })
+               }
+       }
+
+       if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+       }
+
+       // wrapperFn maps method names to the optimized wrappers selected above,
+       // omitting methods the CombineFn does not define.
+       wrapperFn := func(fn interface{}) map[string]reflectx.Func {
+               m := map[string]reflectx.Func{}
+               if mergeAccumulatorsWrapper != nil {
+                       m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+               }
+               if createAccumulatorWrapper != nil {
+                       m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+               }
+               if addInputWrapper != nil {
+                       m["AddInput"] = addInputWrapper(fn)
+               }
+               if extractOutputWrapper != nil {
+                       m["ExtractOutput"] = extractOutputWrapper(fn)
+               }
+
+               return m
+       }
+       // Associate the wrappers with the pointed-to CombineFn type.
+       reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn)
+}
+
+// Combiner2 registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions, each of which should be used for a different situation.
+// Combiner2 should be used when your accumulator, input, and output are 2 distinct types.
+// It can be called with registration.Combiner2[T1, T2](&CustomCombiner{})
+// where T1 is the type of the accumulator and T2 is the other type.
+//
+// The first type parameter is always the accumulator type. MergeAccumulators
+// must take and return it, and CreateAccumulator (if present) must return it,
+// otherwise Combiner2 panics. AddInput may take either type parameter as its
+// input, and ExtractOutput may return either. Every method may optionally add a
+// trailing error result. accum is expected to be a pointer to the CombineFn
+// struct, since the registered type is reflect.TypeOf(accum).Elem().
+func Combiner2[T0, T1 any](accum interface{}) {
+       registerCombinerTypes(accum)
+       accumVal := reflect.ValueOf(accum)
+       // For each method, probe the signature with a trailing error result first,
+       // then the plain one. A match registers a reflectx caller factory for the bare
+       // func type and records a wrapper that binds the method of a concrete
+       // CombineFn value via an interface assertion at call time.
+       var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(mergeAccumulators2x2[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) (T0, error))
+                       return &caller2x2[T0, T0, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+               mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+                               return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(mergeAccumulators2x1[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) T0)
+                       return &caller2x1[T0, T0, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+               mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+                               return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+                       })
+               }
+       }
+
+       // MergeAccumulators is mandatory, so an unmatched signature is always fatal.
+       if mergeAccumulatorsWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+       }
+
+       var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(createAccumulator0x2[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func() (T0, error))
+                       return &caller0x2[T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+               createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func() (T0, error) {
+                               return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+                       })
+               }
+       } else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func() T0)
+                       return &caller0x1[T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+               createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func() T0 {
+                               return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+                       })
+               }
+       }
+       // Optional methods only fail registration when they exist but matched none
+       // of the probed signatures.
+       if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+       }
+
+       // AddInput's input is probed as T0 first, then T1. If both type parameters
+       // are instantiated with the same type, the earlier probe wins.
+       var addInputWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(addInput2x2[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) (T0, error))
+                       return &caller2x2[T0, T0, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+                               return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) T0)
+                       return &caller2x1[T0, T0, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+                               return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(addInput2x2[T0, T1]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T1) (T0, error))
+                       return &caller2x2[T0, T1, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) (T0, error))(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T1) (T0, error) {
+                               return fn.(addInput2x2[T0, T1]).AddInput(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(addInput2x1[T0, T1]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T1) T0)
+                       return &caller2x1[T0, T1, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) T0)(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T1) T0 {
+                               return fn.(addInput2x1[T0, T1]).AddInput(a0, a1)
+                       })
+               }
+       }
+
+       if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+       }
+
+       // ExtractOutput's result is probed as T0 first, then T1, with the same
+       // earliest-match-wins rule as AddInput.
+       var extractOutputWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(extractOutput1x2[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) (T0, error))
+                       return &caller1x2[T0, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, error))(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+                               return fn.(extractOutput1x2[T0, T0]).ExtractOutput(a0)
+                       })
+               }
+       } else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) T0)
+                       return &caller1x1[T0, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T0)(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) T0 {
+                               return fn.(extractOutput1x1[T0, T0]).ExtractOutput(a0)
+                       })
+               }
+       } else if _, ok := accum.(extractOutput1x2[T0, T1]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) (T1, error))
+                       return &caller1x2[T0, T1, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T1, error))(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) (T1, error) {
+                               return fn.(extractOutput1x2[T0, T1]).ExtractOutput(a0)
+                       })
+               }
+       } else if _, ok := accum.(extractOutput1x1[T0, T1]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) T1)
+                       return &caller1x1[T0, T1]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T1)(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) T1 {
+                               return fn.(extractOutput1x1[T0, T1]).ExtractOutput(a0)
+                       })
+               }
+       }
+
+       if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+       }
+
+       // wrapperFn maps method names to the optimized wrappers selected above,
+       // omitting methods the CombineFn does not define.
+       wrapperFn := func(fn interface{}) map[string]reflectx.Func {
+               m := map[string]reflectx.Func{}
+               if mergeAccumulatorsWrapper != nil {
+                       m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+               }
+               if createAccumulatorWrapper != nil {
+                       m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+               }
+               if addInputWrapper != nil {
+                       m["AddInput"] = addInputWrapper(fn)
+               }
+               if extractOutputWrapper != nil {
+                       m["ExtractOutput"] = extractOutputWrapper(fn)
+               }
+
+               return m
+       }
+       // Associate the wrappers with the pointed-to CombineFn type.
+       reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn)
+}
+
+// Combiner3 registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different Combiner
+// functions, each of which should be used for a different situation.
+// Combiner3 should be used when your accumulator, input, and output are 3 distinct types.
+// It can be called with registration.Combiner3[T1, T2, T3](&CustomCombiner{})
+// where T1 is the type of the accumulator, T2 is the type of the input, and T3 is the type of the output.
+//
+// The first type parameter is always the accumulator type. MergeAccumulators
+// must take and return it, and CreateAccumulator (if present) must return it,
+// otherwise Combiner3 panics. The implementation accepts any of the three type
+// parameters as AddInput's input and as ExtractOutput's result, probing them in
+// order. Every method may optionally add a trailing error result. accum is
+// expected to be a pointer to the CombineFn struct, since the registered type is
+// reflect.TypeOf(accum).Elem().
+func Combiner3[T0, T1, T2 any](accum interface{}) {
+       registerCombinerTypes(accum)
+       accumVal := reflect.ValueOf(accum)
+       // For each method, probe the signature with a trailing error result first,
+       // then the plain one. A match registers a reflectx caller factory for the bare
+       // func type and records a wrapper that binds the method of a concrete
+       // CombineFn value via an interface assertion at call time.
+       var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(mergeAccumulators2x2[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) (T0, error))
+                       return &caller2x2[T0, T0, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+               mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+                               return fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(mergeAccumulators2x1[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) T0)
+                       return &caller2x1[T0, T0, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+               mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+                               return fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+                       })
+               }
+       }
+
+       // MergeAccumulators is mandatory, so an unmatched signature is always fatal.
+       if mergeAccumulatorsWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner %v. Failed to infer types", accum))
+       }
+
+       var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(createAccumulator0x2[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func() (T0, error))
+                       return &caller0x2[T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, error))(nil)).Elem(), caller)
+
+               createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func() (T0, error) {
+                               return fn.(createAccumulator0x2[T0]).CreateAccumulator()
+                       })
+               }
+       } else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func() T0)
+                       return &caller0x1[T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), caller)
+
+               createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func() T0 {
+                               return fn.(createAccumulator0x1[T0]).CreateAccumulator()
+                       })
+               }
+       }
+       // Optional methods only fail registration when they exist but matched none
+       // of the probed signatures.
+       if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && createAccumulatorWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner %v. Failed to infer types", accum))
+       }
+
+       // AddInput's input is probed as T0, then T1, then T2. If several type
+       // parameters are instantiated with the same type, the earliest probe wins.
+       var addInputWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(addInput2x2[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) (T0, error))
+                       return &caller2x2[T0, T0, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, error))(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) {
+                               return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T0) T0)
+                       return &caller2x1[T0, T0, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) T0)(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+                               return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(addInput2x2[T0, T1]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T1) (T0, error))
+                       return &caller2x2[T0, T1, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) (T0, error))(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T1) (T0, error) {
+                               return fn.(addInput2x2[T0, T1]).AddInput(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(addInput2x1[T0, T1]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T1) T0)
+                       return &caller2x1[T0, T1, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) T0)(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T1) T0 {
+                               return fn.(addInput2x1[T0, T1]).AddInput(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(addInput2x2[T0, T2]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T2) (T0, error))
+                       return &caller2x2[T0, T2, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T2) (T0, error))(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T2) (T0, error) {
+                               return fn.(addInput2x2[T0, T2]).AddInput(a0, a1)
+                       })
+               }
+       } else if _, ok := accum.(addInput2x1[T0, T2]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0, T2) T0)
+                       return &caller2x1[T0, T2, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T2) T0)(nil)).Elem(), caller)
+
+               addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T2) T0 {
+                               return fn.(addInput2x1[T0, T2]).AddInput(a0, a1)
+                       })
+               }
+       }
+
+       if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed to infer types", accum))
+       }
+
+       // ExtractOutput's result is probed as T0, then T1, then T2, with the same
+       // earliest-match-wins rule as AddInput.
+       var extractOutputWrapper func(fn interface{}) reflectx.Func
+       if _, ok := accum.(extractOutput1x2[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) (T0, error))
+                       return &caller1x2[T0, T0, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, error))(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+                               return fn.(extractOutput1x2[T0, T0]).ExtractOutput(a0)
+                       })
+               }
+       } else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) T0)
+                       return &caller1x1[T0, T0]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T0)(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) T0 {
+                               return fn.(extractOutput1x1[T0, T0]).ExtractOutput(a0)
+                       })
+               }
+       } else if _, ok := accum.(extractOutput1x2[T0, T1]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) (T1, error))
+                       return &caller1x2[T0, T1, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T1, error))(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) (T1, error) {
+                               return fn.(extractOutput1x2[T0, T1]).ExtractOutput(a0)
+                       })
+               }
+       } else if _, ok := accum.(extractOutput1x1[T0, T1]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) T1)
+                       return &caller1x1[T0, T1]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T1)(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) T1 {
+                               return fn.(extractOutput1x1[T0, T1]).ExtractOutput(a0)
+                       })
+               }
+       } else if _, ok := accum.(extractOutput1x2[T0, T2]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) (T2, error))
+                       return &caller1x2[T0, T2, error]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T2, error))(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) (T2, error) {
+                               return fn.(extractOutput1x2[T0, T2]).ExtractOutput(a0)
+                       })
+               }
+       } else if _, ok := accum.(extractOutput1x1[T0, T2]); ok {
+               caller := func(fn interface{}) reflectx.Func {
+                       f := fn.(func(T0) T2)
+                       return &caller1x1[T0, T2]{fn: f}
+               }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) T2)(nil)).Elem(), caller)
+
+               extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) T2 {
+                               return fn.(extractOutput1x1[T0, T2]).ExtractOutput(a0)
+                       })
+               }
+       }
+
+       if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && extractOutputWrapper == nil {
+               panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. Failed to infer types", accum))
+       }
+
+       // wrapperFn maps method names to the optimized wrappers selected above,
+       // omitting methods the CombineFn does not define.
+       wrapperFn := func(fn interface{}) map[string]reflectx.Func {
+               m := map[string]reflectx.Func{}
+               if mergeAccumulatorsWrapper != nil {
+                       m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+               }
+               if createAccumulatorWrapper != nil {
+                       m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+               }
+               if addInputWrapper != nil {
+                       m["AddInput"] = addInputWrapper(fn)
+               }
+               if extractOutputWrapper != nil {
+                       m["ExtractOutput"] = extractOutputWrapper(fn)
+               }
+
+               return m
+       }
+       // Associate the wrappers with the pointed-to CombineFn type.
+       reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn)
+}
+
+// registerCombinerTypes registers the CombineFn's pointed-to type with the
+// runtime and schema registries, then registers the parameter and result
+// types of its MergeAccumulators, AddInput, and ExtractOutput methods.
+// CreateAccumulator is not inspected separately: it takes no inputs and its
+// result is the accumulator type that MergeAccumulators already covers.
+//
+// accum must be a non-nil pointer to the CombineFn and must define
+// MergeAccumulators. Otherwise this panics with a message naming the combiner,
+// rather than with an opaque reflect panic.
+func registerCombinerTypes(accum interface{}) {
+       accumType := reflect.TypeOf(accum)
+       // Combiners are registered by their pointed-to type, so require a pointer.
+       // This also catches a nil accum, for which reflect.TypeOf returns nil.
+       if accumType == nil || accumType.Kind() != reflect.Ptr {
+               panic(fmt.Sprintf("Failed to register combiner %v. Combiners must be registered as a pointer, e.g. &CustomCombiner{}", accum))
+       }
+
+       // Register the combiner
+       runtime.RegisterType(accumType.Elem())
+       schema.RegisterType(accumType.Elem())
+
+       // Register all types in the Combiner.
+       // There may be different types across MergeAccumulators, AddInput, and ExtractOutput.
+       accumVal := reflect.ValueOf(accum)
+       // MergeAccumulators is the one mandatory CombineFn method. Without this
+       // check, calling Type on the invalid Value panics with a bare reflect error
+       // that does not identify the offending combiner.
+       mergeAccumulators := accumVal.MethodByName("MergeAccumulators")
+       if !mergeAccumulators.IsValid() {
+               panic(fmt.Sprintf("Failed to register combiner %v. Combiners must define a MergeAccumulators method", accum))
+       }
+       registerMethodTypes(mergeAccumulators.Type())
+       if m := accumVal.MethodByName("AddInput"); m.IsValid() {
+               registerMethodTypes(m.Type())
+       }
+       if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() {
+               registerMethodTypes(m.Type())
+       }
+}
+
 func registerDoFnTypes(doFn interface{}) {
        // Register the doFn
        runtime.RegisterType(reflect.TypeOf(doFn).Elem())
        schema.RegisterType(reflect.TypeOf(doFn).Elem())
 
        // Register all types in the DoFn
-       fn := reflect.ValueOf(doFn).MethodByName("ProcessElement").Type()
+       // Only ProcessElement's signature is inspected for types to register.
+       processElement := reflect.ValueOf(doFn).MethodByName("ProcessElement")
+       registerMethodTypes(processElement.Type())
+}
+
+func registerMethodTypes(fn reflect.Type) {
        for i := 0; i < fn.NumIn(); i++ {
                in := reflectx.SkipPtr(fn.In(i))
                if in.Kind() == reflect.Struct {
diff --git a/sdks/go/pkg/beam/registration/registration.tmpl b/sdks/go/pkg/beam/registration/registration.tmpl
index e046a9d165f..f73c0e323e2 100644
--- a/sdks/go/pkg/beam/registration/registration.tmpl
+++ b/sdks/go/pkg/beam/registration/registration.tmpl
@@ -118,6 +118,7 @@ package registration
 
 import (
        "context"
+    "fmt"
        "reflect"
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
@@ -241,13 +242,320 @@ type teardown1x1 interface {
        Teardown(ctx context.Context) error
 }
 
+// Method-shape interfaces used to detect, via type assertion, which
+// signature a combiner's optional and required methods were declared with.
+// The "NxM" suffix is the number of parameters x number of results.
+type (
+    // createAccumulator0x1 matches CreateAccumulator() T.
+    createAccumulator0x1[T any] interface {
+        CreateAccumulator() T
+    }
+
+    // createAccumulator0x2 matches CreateAccumulator() (T, error).
+    createAccumulator0x2[T any] interface {
+        CreateAccumulator() (T, error)
+    }
+
+    // addInput2x1 matches AddInput(accumulator, input) accumulator.
+    addInput2x1[T1, T2 any] interface {
+        AddInput(acc T1, in T2) T1
+    }
+
+    // addInput2x2 matches AddInput(accumulator, input) (accumulator, error).
+    addInput2x2[T1, T2 any] interface {
+        AddInput(acc T1, in T2) (T1, error)
+    }
+
+    // mergeAccumulators2x1 matches MergeAccumulators(a, b) merged.
+    mergeAccumulators2x1[T any] interface {
+        MergeAccumulators(first, second T) T
+    }
+
+    // mergeAccumulators2x2 matches MergeAccumulators(a, b) (merged, error).
+    mergeAccumulators2x2[T any] interface {
+        MergeAccumulators(first, second T) (T, error)
+    }
+
+    // extractOutput1x1 matches ExtractOutput(accumulator) output.
+    extractOutput1x1[T1, T2 any] interface {
+        ExtractOutput(acc T1) T2
+    }
+
+    // extractOutput1x2 matches ExtractOutput(accumulator) (output, error).
+    extractOutput1x2[T1, T2 any] interface {
+        ExtractOutput(acc T1) (T2, error)
+    }
+)
+
+{{range $accum := upto 3}}{{$genericParams := (add $accum 1)}}
+// Combiner{{$genericParams}} registers a CombineFn's structural functions
+// and types and optimizes their runtime execution. There are 3 different 
Combiner
+// functions, each of which should be used for a different situation.
+{{if (eq $genericParams 1)}}// Combiner1 should be used when your accumulator, 
input, and output are all of the same type.
+// It can be called with registration.Combiner1[T](&CustomCombiner{})
+// where T is the type of the input/accumulator/output.
+{{else}}{{if (eq $genericParams 2)}}// Combiner2 should be used when your 
accumulator, input, and output are 2 distinct types.
+// It can be called with registration.Combiner2[T1, T2](&CustomCombiner{})
+// where T1 is the type of the accumulator and T2 is the other type.
+{{else}}// Combiner3 should be used when your accumulator, input, and output 
are 3 distinct types.
+// It can be called with registration.Combiner3[T1, T2, T3](&CustomCombiner{})
+// where T1 is the type of the accumulator, T2 is the type of the input, and 
T3 is the type of the output.
+{{end}}{{end}}func Combiner{{$genericParams}}[{{range $paramNum := upto 
$genericParams}}{{if $paramNum}}, {{end}}T{{$paramNum}}{{end}} any](accum 
interface{}) {
+    registerCombinerTypes(accum)
+    accumVal := reflect.ValueOf(accum)
+    // MergeAccumulators must match one of the signatures below; the
+    // error-returning form is checked first.
+    var mergeAccumulatorsWrapper func(fn interface{}) reflectx.Func
+    if _, ok := accum.(mergeAccumulators2x2[T0]); ok {
+        // Register a typed caller for this func signature with reflectx.
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T0) (T0, error))
+            return &caller2x2[T0, T0, T0, error]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, 
error))(nil)).Elem(), caller)
+
+        mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) 
{
+                               return 
fn.(mergeAccumulators2x2[T0]).MergeAccumulators(a0, a1)
+                       })
+               }
+    } else if _, ok := accum.(mergeAccumulators2x1[T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T0) T0)
+            return &caller2x1[T0, T0, T0]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) 
T0)(nil)).Elem(), caller)
+
+        mergeAccumulatorsWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+                               return 
fn.(mergeAccumulators2x1[T0]).MergeAccumulators(a0, a1)
+                       })
+               }
+    }
+
+    if mergeAccumulatorsWrapper == nil {
+        panic(fmt.Sprintf("Failed to optimize MergeAccumulators for combiner 
%v. Failed to infer types", accum))
+    }
+
+    // CreateAccumulator is optional: only fail if the method exists but
+    // matches neither signature.
+    var createAccumulatorWrapper func(fn interface{}) reflectx.Func
+    if _, ok := accum.(createAccumulator0x2[T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func() (T0, error))
+            return &caller0x2[T0, error]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func() (T0, 
error))(nil)).Elem(), caller)
+
+        createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func() (T0, error) {
+                               return 
fn.(createAccumulator0x2[T0]).CreateAccumulator()
+                       })
+               }
+    } else if _, ok := accum.(createAccumulator0x1[T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func() T0)
+            return &caller0x1[T0]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func() T0)(nil)).Elem(), 
caller)
+
+        createAccumulatorWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func() T0 {
+                               return 
fn.(createAccumulator0x1[T0]).CreateAccumulator()
+                       })
+               }
+    }
+    if m := accumVal.MethodByName("CreateAccumulator"); m.IsValid() && 
createAccumulatorWrapper == nil {
+        panic(fmt.Sprintf("Failed to optimize CreateAccumulator for combiner 
%v. Failed to infer types", accum))
+    }
+
+    // AddInput is optional. The input type is tried as T0 first, then as each
+    // additional type parameter of this Combiner variant, if any.
+    var addInputWrapper func(fn interface{}) reflectx.Func
+    if _, ok := accum.(addInput2x2[T0, T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T0) (T0, error))
+            return &caller2x2[T0, T0, T0, error]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) (T0, 
error))(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) (T0, error) 
{
+                               return fn.(addInput2x2[T0, T0]).AddInput(a0, a1)
+                       })
+               }
+    } else if _, ok := accum.(addInput2x1[T0, T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T0) T0)
+            return &caller2x1[T0, T0, T0]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T0) 
T0)(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T0) T0 {
+                               return fn.(addInput2x1[T0, T0]).AddInput(a0, a1)
+                       })
+               }
+    } {{if (gt $genericParams 1)}} else if _, ok := accum.(addInput2x2[T0, 
T1]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T1) (T0, error))
+            return &caller2x2[T0, T1, T0, error]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) (T0, 
error))(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T1) (T0, error) 
{
+                               return fn.(addInput2x2[T0, T1]).AddInput(a0, a1)
+                       })
+               }
+    } else if _, ok := accum.(addInput2x1[T0, T1]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T1) T0)
+            return &caller2x1[T0, T1, T0]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T1) 
T0)(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T1) T0 {
+                               return fn.(addInput2x1[T0, T1]).AddInput(a0, a1)
+                       })
+               }
+    } {{end}}{{if (gt $genericParams 2)}} else if _, ok := 
accum.(addInput2x2[T0, T2]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T2) (T0, error))
+            return &caller2x2[T0, T2, T0, error]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T2) (T0, 
error))(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T2) (T0, error) 
{
+                               return fn.(addInput2x2[T0, T2]).AddInput(a0, a1)
+                       })
+               }
+    } else if _, ok := accum.(addInput2x1[T0, T2]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0, T2) T0)
+            return &caller2x1[T0, T2, T0]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0, T2) 
T0)(nil)).Elem(), caller)
+
+        addInputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0, a1 T2) T0 {
+                               return fn.(addInput2x1[T0, T2]).AddInput(a0, a1)
+                       })
+               }
+    } {{end}}
+
+    if m := accumVal.MethodByName("AddInput"); m.IsValid() && addInputWrapper 
== nil {
+        panic(fmt.Sprintf("Failed to optimize AddInput for combiner %v. Failed 
to infer types", accum))
+    }
+
+    // ExtractOutput is optional. The output type is tried as T0 first, then as
+    // each additional type parameter of this Combiner variant, if any.
+    var extractOutputWrapper func(fn interface{}) reflectx.Func
+    if _, ok := accum.(extractOutput1x2[T0, T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) (T0, error))
+            return &caller1x2[T0, T0, error]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T0, 
error))(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) (T0, error) {
+                               return fn.(extractOutput1x2[T0, 
T0]).ExtractOutput(a0)
+                       })
+               }
+    } else if _, ok := accum.(extractOutput1x1[T0, T0]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) T0)
+            return &caller1x1[T0, T0]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) 
T0)(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) T0 {
+                               return fn.(extractOutput1x1[T0, 
T0]).ExtractOutput(a0)
+                       })
+               }
+    } {{if (gt $genericParams 1)}} else if _, ok := 
accum.(extractOutput1x2[T0, T1]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) (T1, error))
+            return &caller1x2[T0, T1, error]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T1, 
error))(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) (T1, error) {
+                               return fn.(extractOutput1x2[T0, 
T1]).ExtractOutput(a0)
+                       })
+               }
+    } else if _, ok := accum.(extractOutput1x1[T0, T1]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) T1)
+            return &caller1x1[T0, T1]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) 
T1)(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) T1 {
+                               return fn.(extractOutput1x1[T0, 
T1]).ExtractOutput(a0)
+                       })
+               }
+    } {{end}}{{if (gt $genericParams 2)}} else if _, ok := 
accum.(extractOutput1x2[T0, T2]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) (T2, error))
+            return &caller1x2[T0, T2, error]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) (T2, 
error))(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) (T2, error) {
+                               return fn.(extractOutput1x2[T0, 
T2]).ExtractOutput(a0)
+                       })
+               }
+    } else if _, ok := accum.(extractOutput1x1[T0, T2]); ok {
+        caller := func(fn interface{}) reflectx.Func {
+            f := fn.(func(T0) T2)
+            return &caller1x1[T0, T2]{fn: f}
+        }
+               reflectx.RegisterFunc(reflect.TypeOf((*func(T0) 
T2)(nil)).Elem(), caller)
+
+        extractOutputWrapper = func(fn interface{}) reflectx.Func {
+                       return reflectx.MakeFunc(func(a0 T0) T2 {
+                               return fn.(extractOutput1x1[T0, 
T2]).ExtractOutput(a0)
+                       })
+               }
+    } {{end}}
+
+    if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() && 
extractOutputWrapper == nil {
+        panic(fmt.Sprintf("Failed to optimize ExtractOutput for combiner %v. 
Failed to infer types", accum))
+    }
+
+       // Build the method table for a combiner instance, containing only the
+       // methods whose signatures were matched above.
+       wrapperFn := func(fn interface{}) map[string]reflectx.Func {
+               m := map[string]reflectx.Func{}
+               if mergeAccumulatorsWrapper != nil {
+                       m["MergeAccumulators"] = mergeAccumulatorsWrapper(fn)
+               }
+               if createAccumulatorWrapper != nil {
+                       m["CreateAccumulator"] = createAccumulatorWrapper(fn)
+               }
+               if addInputWrapper != nil {
+                       m["AddInput"] = addInputWrapper(fn)
+               }
+               if extractOutputWrapper != nil {
+                       m["ExtractOutput"] = extractOutputWrapper(fn)
+               }
+
+               return m
+       }
+       reflectx.RegisterStructWrapper(reflect.TypeOf(accum).Elem(), wrapperFn)
+}{{end}}
+
+// registerCombinerTypes registers the combiner's underlying type (accum is
+// expected to be a pointer; its Elem type is registered) with the runtime and
+// schema registries, then passes the signatures of MergeAccumulators and, when
+// present, AddInput and ExtractOutput to registerMethodTypes.
+// It panics with a descriptive message if accum has no MergeAccumulators method.
+func registerCombinerTypes(accum interface{}) {
+    accumVal := reflect.ValueOf(accum)
+    mergeAccumulators := accumVal.MethodByName("MergeAccumulators")
+    if !mergeAccumulators.IsValid() {
+        // Without this check, calling Type() on the zero Value below would panic
+        // with reflect's opaque "call of reflect.Value.Type on zero Value".
+        panic(fmt.Sprintf("Failed to register combiner %v. Combiners must implement MergeAccumulators", accum))
+    }
+
+    // Register the combiner
+    runtime.RegisterType(reflect.TypeOf(accum).Elem())
+    schema.RegisterType(reflect.TypeOf(accum).Elem())
+
+    // Register all types in the Combiner.
+    // There may be different types across MergeAccumulators, AddInput, and ExtractOutput.
+    registerMethodTypes(mergeAccumulators.Type())
+    if m := accumVal.MethodByName("AddInput"); m.IsValid() {
+        registerMethodTypes(m.Type())
+    }
+    if m := accumVal.MethodByName("ExtractOutput"); m.IsValid() {
+        registerMethodTypes(m.Type())
+    }
+}
+
+// registerDoFnTypes registers the DoFn's Elem type (doFn is presumably a
+// pointer to the DoFn struct) with the runtime and schema registries, then
+// passes the ProcessElement signature to registerMethodTypes.
 func registerDoFnTypes(doFn interface{}) {
    // Register the doFn
    runtime.RegisterType(reflect.TypeOf(doFn).Elem())
    schema.RegisterType(reflect.TypeOf(doFn).Elem())
   
    // Register all types in the DoFn
-   fn := reflect.ValueOf(doFn).MethodByName("ProcessElement").Type()
+   
registerMethodTypes(reflect.ValueOf(doFn).MethodByName("ProcessElement").Type())
+}
+
+func registerMethodTypes(fn reflect.Type) {
    for i := 0; i < fn.NumIn(); i++ {
        in := reflectx.SkipPtr(fn.In(i))
        if in.Kind() == reflect.Struct {
diff --git a/sdks/go/pkg/beam/registration/registration_test.go 
b/sdks/go/pkg/beam/registration/registration_test.go
index 4f5c37dd2a8..2b535692ff2 100644
--- a/sdks/go/pkg/beam/registration/registration_test.go
+++ b/sdks/go/pkg/beam/registration/registration_test.go
@@ -159,6 +159,168 @@ func TestRegister_RegistersTypes(t *testing.T) {
        }
 }
 
+func TestCombiner_CompleteCombiner3(t *testing.T) {
+       accum := &CompleteCombiner3{}
+       Combiner3[int, CustomType, CustomType2](accum)
+
+       m, ok := reflectx.WrapMethods(&CompleteCombiner3{})
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no 
registered entry found")
+       }
+       ca, ok := m["CreateAccumulator"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no 
registered entry found for CreateAccumulator")
+       }
+       if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+               t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", 
got, want)
+       }
+       ai, ok := m["AddInput"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no 
registered entry found for AddInput")
+       }
+       if got, want := ai.Call([]interface{}{2, CustomType{val: 3}})[0].(int), 
5; got != want {
+               t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+       }
+       ma, ok := m["MergeAccumulators"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no 
registered entry found for MergeAccumulators")
+       }
+       if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+               t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", 
got, want)
+       }
+       eo, ok := m["ExtractOutput"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner3{}), no 
registered entry found for MergeAccumulators")
+       }
+       if got, want := eo.Call([]interface{}{2})[0].(CustomType2).val2, 2; got 
!= want {
+               t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", 
got, want)
+       }
+}
+
+func TestCombiner_RegistersTypes(t *testing.T) {
+       accum := &CompleteCombiner3{}
+       Combiner3[int, CustomType, CustomType2](accum)
+
+       // Need to call FromType so that the registry will reconcile its 
registrations
+       schema.FromType(reflect.TypeOf(accum).Elem())
+       if !schema.Registered(reflect.TypeOf(accum).Elem()) {
+               
t.Errorf("schema.Registered(reflect.TypeOf(CustomTypeDoFn1x1{})) = false, want 
true")
+       }
+       if !schema.Registered(reflect.TypeOf(CustomType{})) {
+               t.Errorf("schema.Registered(reflect.TypeOf(CustomType{})) = 
false, want true")
+       }
+       if !schema.Registered(reflect.TypeOf(CustomType2{})) {
+               t.Errorf("schema.Registered(reflect.TypeOf(CustomType{})) = 
false, want true")
+       }
+}
+
+func TestCombiner_CompleteCombiner2(t *testing.T) {
+       accum := &CompleteCombiner2{}
+       Combiner2[int, CustomType](accum)
+
+       m, ok := reflectx.WrapMethods(&CompleteCombiner2{})
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no 
registered entry found")
+       }
+       ca, ok := m["CreateAccumulator"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no 
registered entry found for CreateAccumulator")
+       }
+       if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+               t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", 
got, want)
+       }
+       ai, ok := m["AddInput"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no 
registered entry found for AddInput")
+       }
+       if got, want := ai.Call([]interface{}{2, CustomType{val: 3}})[0].(int), 
5; got != want {
+               t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+       }
+       ma, ok := m["MergeAccumulators"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no 
registered entry found for MergeAccumulators")
+       }
+       if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+               t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", 
got, want)
+       }
+       eo, ok := m["ExtractOutput"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner2{}), no 
registered entry found for MergeAccumulators")
+       }
+       if got, want := eo.Call([]interface{}{2})[0].(CustomType).val, 2; got 
!= want {
+               t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", 
got, want)
+       }
+}
+
+func TestCombiner_CompleteCombiner1(t *testing.T) {
+       accum := &CompleteCombiner1{}
+       Combiner1[int](accum)
+
+       m, ok := reflectx.WrapMethods(&CompleteCombiner1{})
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no 
registered entry found")
+       }
+       ca, ok := m["CreateAccumulator"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no 
registered entry found for CreateAccumulator")
+       }
+       if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+               t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", 
got, want)
+       }
+       ai, ok := m["AddInput"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no 
registered entry found for AddInput")
+       }
+       if got, want := ai.Call([]interface{}{2, 3})[0].(int), 5; got != want {
+               t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+       }
+       ma, ok := m["MergeAccumulators"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no 
registered entry found for MergeAccumulators")
+       }
+       if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+               t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", 
got, want)
+       }
+       eo, ok := m["ExtractOutput"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&CompleteCombiner1{}), no 
registered entry found for MergeAccumulators")
+       }
+       if got, want := eo.Call([]interface{}{2})[0].(int), 2; got != want {
+               t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", 
got, want)
+       }
+}
+
+func TestCombiner_PartialCombiner2(t *testing.T) {
+       accum := &PartialCombiner2{}
+       Combiner2[int, CustomType](accum)
+
+       m, ok := reflectx.WrapMethods(&PartialCombiner2{})
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no 
registered entry found")
+       }
+       ca, ok := m["CreateAccumulator"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no 
registered entry found for CreateAccumulator")
+       }
+       if got, want := ca.Call([]interface{}{})[0].(int), 0; got != want {
+               t.Errorf("Wrapped CreateAccumulator call: got %v, want %v", 
got, want)
+       }
+       ai, ok := m["AddInput"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no 
registered entry found for AddInput")
+       }
+       if got, want := ai.Call([]interface{}{2, CustomType{val: 3}})[0].(int), 
5; got != want {
+               t.Errorf("Wrapped AddInput call: got %v, want %v", got, want)
+       }
+       ma, ok := m["MergeAccumulators"]
+       if !ok {
+               t.Fatalf("reflectx.WrapMethods(&PartialCombiner2{}), no 
registered entry found for MergeAccumulators")
+       }
+       if got, want := ma.Call([]interface{}{2, 4})[0].(int), 6; got != want {
+               t.Errorf("Wrapped MergeAccumulators call: got %v, want %v", 
got, want)
+       }
+}
+
 func TestEmitter1(t *testing.T) {
        Emitter1[int]()
        if !exec.IsEmitterRegistered(reflect.TypeOf((*func(int))(nil)).Elem()) {
@@ -438,3 +600,75 @@ type CustomTypeDoFn1x1 struct {
+// ProcessElement copies the input's val field into the val2 field of a new CustomType2.
 func (fn *CustomTypeDoFn1x1) ProcessElement(t CustomType) CustomType2 {
        return CustomType2{val2: t.val}
 }
+
+// CompleteCombiner3 is a test combiner defining all four combiner methods over
+// three distinct types: an int accumulator, CustomType input, and CustomType2 output.
+type CompleteCombiner3 struct{}
+
+// CreateAccumulator starts every accumulation at zero.
+func (c *CompleteCombiner3) CreateAccumulator() int {
+       return 0
+}
+
+// AddInput adds the input's val field to the running sum.
+func (c *CompleteCombiner3) AddInput(acc int, in CustomType) int {
+       acc += in.val
+       return acc
+}
+
+// MergeAccumulators sums two partial accumulators.
+func (c *CompleteCombiner3) MergeAccumulators(a, b int) int {
+       merged := a + b
+       return merged
+}
+
+// ExtractOutput wraps the final sum in a CustomType2.
+func (c *CompleteCombiner3) ExtractOutput(acc int) CustomType2 {
+       var out CustomType2
+       out.val2 = acc
+       return out
+}
+
+// CompleteCombiner2 is a test combiner defining all four combiner methods over
+// two distinct types: an int accumulator and CustomType for input and output.
+type CompleteCombiner2 struct{}
+
+// CreateAccumulator starts every accumulation at zero.
+func (c *CompleteCombiner2) CreateAccumulator() int {
+       return 0
+}
+
+// AddInput adds the input's val field to the running sum.
+func (c *CompleteCombiner2) AddInput(acc int, in CustomType) int {
+       acc += in.val
+       return acc
+}
+
+// MergeAccumulators sums two partial accumulators.
+func (c *CompleteCombiner2) MergeAccumulators(a, b int) int {
+       merged := a + b
+       return merged
+}
+
+// ExtractOutput wraps the final sum in a CustomType.
+func (c *CompleteCombiner2) ExtractOutput(acc int) CustomType {
+       var out CustomType
+       out.val = acc
+       return out
+}
+
+// CompleteCombiner1 is a test combiner defining all four combiner methods where
+// the accumulator, input, and output are all int.
+type CompleteCombiner1 struct{}
+
+// CreateAccumulator starts every accumulation at zero.
+func (c *CompleteCombiner1) CreateAccumulator() int {
+       return 0
+}
+
+// AddInput adds the input to the running sum.
+func (c *CompleteCombiner1) AddInput(acc, in int) int {
+       acc += in
+       return acc
+}
+
+// MergeAccumulators sums two partial accumulators.
+func (c *CompleteCombiner1) MergeAccumulators(a, b int) int {
+       merged := a + b
+       return merged
+}
+
+// ExtractOutput returns the final sum unchanged.
+func (c *CompleteCombiner1) ExtractOutput(acc int) int {
+       out := acc
+       return out
+}
+
+// PartialCombiner2 is a test combiner with an int accumulator and CustomType
+// input that defines no ExtractOutput method.
+type PartialCombiner2 struct{}
+
+// CreateAccumulator starts every accumulation at zero.
+func (c *PartialCombiner2) CreateAccumulator() int {
+       return 0
+}
+
+// AddInput adds the input's val field to the running sum.
+func (c *PartialCombiner2) AddInput(acc int, in CustomType) int {
+       acc += in.val
+       return acc
+}
+
+// MergeAccumulators sums two partial accumulators.
+func (c *PartialCombiner2) MergeAccumulators(a, b int) int {
+       merged := a + b
+       return merged
+}

Reply via email to