This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4bb3970f375 [BEAM-14470] Use lifecycle method names directly. (#17790)
4bb3970f375 is described below

commit 4bb3970f37548dc84ddad47aea7a1a4a1a3fda15
Author: Robert Burke <[email protected]>
AuthorDate: Wed Jun 1 10:52:11 2022 -0700

    [BEAM-14470] Use lifecycle method names directly. (#17790)
---
 sdks/go/pkg/beam/core/graph/fn.go      |  47 ++------
 sdks/go/pkg/beam/core/graph/fn_test.go | 208 +++++++++++++++++++++++++++++----
 2 files changed, 197 insertions(+), 58 deletions(-)

diff --git a/sdks/go/pkg/beam/core/graph/fn.go 
b/sdks/go/pkg/beam/core/graph/fn.go
index 5eea802d5fa..6b3656c1c25 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -123,32 +123,23 @@ func NewFn(fn interface{}) (*Fn, error) {
                                methods[name] = f
                        }
                }
-               // TODO(lostluck): Consider moving this into the reflectx 
package.
-               for i := 0; i < val.Type().NumMethod(); i++ {
-                       m := val.Type().Method(i)
-                       if m.PkgPath != "" {
-                               continue // skip: unexported
-                       }
-                       if m.Name == "String" {
-                               continue // skip: harmless
-                       }
-                       if _, ok := methods[m.Name]; ok {
+               for mName := range lifecycleMethods {
+                       if _, ok := methods[mName]; ok {
                                continue // skip : already wrapped
                        }
+                       m, ok := val.Type().MethodByName(mName)
+                       if !ok {
+                               continue // skip: doesn't exist
+                       }
 
                        // CAVEAT(herohde) 5/22/2017: The type 
val.Type.Method.Type is not
                        // the same as val.Method.Type: the former has the 
explicit receiver.
                        // We'll use the receiver-less version.
-
-                       // TODO(herohde) 5/22/2017: Alternatively, it looks 
like we could
-                       // serialize each method, call them explicitly and 
avoid struct
-                       // registration.
-
-                       f, err := 
funcx.New(reflectx.MakeFunc(val.Method(i).Interface()))
+                       f, err := 
funcx.New(reflectx.MakeFunc(val.Method(m.Index).Interface()))
                        if err != nil {
-                               return nil, errors.Wrapf(err, "method %v 
invalid", m.Name)
+                               return nil, errors.Wrapf(err, "method %v 
invalid", mName)
                        }
-                       methods[m.Name] = f
+                       methods[mName] = f
                }
                return &Fn{Recv: fn, methods: methods, annotations: 
annotations}, nil
 
@@ -450,9 +441,6 @@ func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) {
        if fn.Fn != nil {
                fn.methods[processElementName] = fn.Fn
        }
-       if err := verifyValidNames("graph.AsDoFn", fn, doFnNames...); err != 
nil {
-               return nil, err
-       }
 
        if _, ok := fn.methods[processElementName]; !ok {
                err := errors.Errorf("failed to find %v method", 
processElementName)
@@ -1295,9 +1283,6 @@ func AsCombineFn(fn *Fn) (*CombineFn, error) {
        if fn.Fn != nil {
                fn.methods[mergeAccumulatorsName] = fn.Fn
        }
-       if err := verifyValidNames(fnKind, fn, setupName, 
createAccumulatorName, addInputName, mergeAccumulatorsName, extractOutputName, 
compactName, teardownName); err != nil {
-               return nil, err
-       }
 
        mergeFn, ok := fn.methods[mergeAccumulatorsName]
        if !ok {
@@ -1356,20 +1341,6 @@ func validateSignature(fnKind, methodName string, fn 
*Fn, accumType reflect.Type
        return nil
 }
 
-func verifyValidNames(fnKind string, fn *Fn, names ...string) error {
-       m := make(map[string]bool)
-       for _, name := range names {
-               m[name] = true
-       }
-
-       for key := range fn.methods {
-               if !m[key] {
-                       return errors.Errorf("%s: unexpected exported method %v 
present on %v. Valid methods are: %v", fnKind, key, fn.Name(), names)
-               }
-       }
-       return nil
-}
-
 type verifyMethodError struct {
        // Context for the error.
        fnKind, methodName string
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go 
b/sdks/go/pkg/beam/core/graph/fn_test.go
index 0612f0ec4cb..a1702175f64 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -20,11 +20,13 @@ package graph
 import (
        "context"
        "reflect"
+       "strings"
        "testing"
        "time"
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
 )
 
 func TestNewDoFn(t *testing.T) {
@@ -161,6 +163,7 @@ func TestNewDoFnSdf(t *testing.T) {
                }{
                        {dfn: &GoodSdf{}, main: MainSingle},
                        {dfn: &GoodSdfKv{}, main: MainKv},
+                       {dfn: &GoodIgnoreOtherExportedMethods{}, main: 
MainSingle},
                }
 
                for _, test := range tests {
@@ -205,7 +208,6 @@ func TestNewDoFnSdf(t *testing.T) {
                        {dfn: &BadSdfRestTCreateTracker{}},
                        {dfn: &BadSdfRestTTruncateRestriction{}},
                        // Validate other types
-                       {dfn: &BadSdfRestSizeReturn{}},
                        {dfn: &BadSdfCreateTrackerReturn{}},
                        {dfn: &BadSdfMismatchedRTracker{}},
                        {dfn: &BadSdfMissingRTracker{}},
@@ -321,6 +323,7 @@ func TestNewCombineFn(t *testing.T) {
                        {cfn: &GoodWErrorCombineFn{}},
                        {cfn: &GoodWContextCombineFn{}},
                        {cfn: &GoodCombineFnUnexportedExtraMethod{}},
+                       {cfn: &GoodCombineFnExtraExportedMethod{}},
                }
 
                for _, test := range tests {
@@ -363,7 +366,6 @@ func TestNewCombineFn(t *testing.T) {
                        {cfn: &BadCombineFnInvalidExtractOutput1{}},
                        {cfn: &BadCombineFnInvalidExtractOutput2{}},
                        {cfn: &BadCombineFnInvalidExtractOutput3{}},
-                       {cfn: &BadCombineFnExtraExportedMethod{}},
                }
                for _, test := range tests {
                        t.Run(reflect.TypeOf(test.cfn).String(), func(t 
*testing.T) {
@@ -378,6 +380,166 @@ func TestNewCombineFn(t *testing.T) {
        })
 }
 
+func TestNewFn_DoFn(t *testing.T) {
+       // Register a wrapper that supplies only ProcessElement, so NewFn must find the remaining lifecycle methods by name.
+       reflectx.RegisterStructWrapper(reflect.TypeOf((*GoodDoFn)(nil)).Elem(), 
func(fn interface{}) map[string]reflectx.Func {
+               gdf := fn.(*GoodDoFn)
+               return map[string]reflectx.Func{
+                       processElementName: reflectx.MakeFunc1x1(func(v int) 
int {
+                               return gdf.ProcessElement(v)
+                       }),
+               }
+       })
+
+       userFn := &GoodDoFn{}
+       fn, err := NewFn(userFn)
+       if err != nil {
+               t.Fatalf("NewFn(%T) failed:\n%v", userFn, err)
+       }
+       dofn, err := AsDoFn(fn, MainSingle)
+       if err != nil {
+               t.Fatalf("AsDoFn(%v, MainSingle) failed:\n%v", fn.Name(), err)
+       }
+       // Check that we get expected values for all the methods.
+       if got, want := dofn.Name(), "GoodDoFn"; !strings.HasSuffix(got, want) {
+               t.Errorf("(%v).Name() = %q, want suffix %q", dofn.Name(), got, 
want)
+       }
+       if dofn.SetupFn() == nil {
+               t.Errorf("(%v).SetupFn() == nil, want value", dofn.Name())
+       }
+       if dofn.StartBundleFn() == nil {
+               t.Errorf("(%v).StartBundleFn() == nil, want value", dofn.Name())
+       }
+       if dofn.ProcessElementFn() == nil {
+               t.Errorf("(%v).ProcessElementFn() == nil, want value", 
dofn.Name())
+       }
+       if dofn.FinishBundleFn() == nil {
+               t.Errorf("(%v).FinishBundleFn() == nil, want value", 
dofn.Name())
+       }
+       if dofn.TeardownFn() == nil {
+               t.Errorf("(%v).TeardownFn() == nil, want value", dofn.Name())
+       }
+       if dofn.IsSplittable() {
+               t.Errorf("(%v).IsSplittable() = true, want false", dofn.Name())
+       }
+}
+
+func TestNewFn_SplittableDoFn(t *testing.T) {
+       userFn := &GoodStatefulWatermarkEstimating{}
+       fn, err := NewFn(userFn)
+       if err != nil {
+               t.Fatalf("NewFn(%T) failed:\n%v", userFn, err)
+       }
+       dofn, err := AsDoFn(fn, MainSingle)
+       if err != nil {
+               t.Fatalf("AsDoFn(%v, MainSingle) failed:\n%v", fn.Name(), err)
+       }
+       // Check that we get expected values for all the methods.
+       if dofn.SetupFn() == nil {
+               t.Errorf("(%v).SetupFn() == nil, want value", dofn.Name())
+       }
+       if dofn.StartBundleFn() == nil {
+               t.Errorf("(%v).StartBundleFn() == nil, want value", dofn.Name())
+       }
+       if dofn.ProcessElementFn() == nil {
+               t.Errorf("(%v).ProcessElementFn() == nil, want value", 
dofn.Name())
+       }
+       if dofn.FinishBundleFn() == nil {
+               t.Errorf("(%v).FinishBundleFn() == nil, want value", 
dofn.Name())
+       }
+       if dofn.TeardownFn() == nil {
+               t.Errorf("(%v).TeardownFn() == nil, want value", dofn.Name())
+       }
+
+       if !dofn.IsSplittable() {
+               t.Fatalf("(%v).IsSplittable() = false, want true", dofn.Name())
+       }
+       sdofn := (*SplittableDoFn)(dofn)
+
+       if got, want := sdofn.Name(), "GoodStatefulWatermarkEstimating"; 
!strings.HasSuffix(got, want) {
+               t.Errorf("(%v).Name() = %q, want suffix %q", sdofn.Name(), got, 
want)
+       }
+       if sdofn.CreateInitialRestrictionFn() == nil {
+               t.Errorf("(%v).CreateInitialRestrictionFn() == nil, want 
value", sdofn.Name())
+       }
+       if sdofn.CreateTrackerFn() == nil {
+               t.Errorf("(%v).CreateTrackerFn() == nil, want value", 
sdofn.Name())
+       }
+       if sdofn.RestrictionSizeFn() == nil {
+               t.Errorf("(%v).RestrictionSizeFn() == nil, want value", 
sdofn.Name())
+       }
+       if got, want := sdofn.RestrictionT(), reflect.TypeOf(RestT{}); got != 
want {
+               t.Errorf("(%v).RestrictionT() == %v, want %v", sdofn.Name(), 
got, want)
+       }
+       if sdofn.SplitRestrictionFn() == nil {
+               t.Errorf("(%v).SplitRestrictionFn() == nil, want value", 
sdofn.Name())
+       }
+       if !sdofn.HasTruncateRestriction() {
+               t.Fatalf("(%v).HasTruncateRestriction() = false, want true", 
dofn.Name())
+       }
+       if sdofn.TruncateRestrictionFn() == nil {
+               t.Errorf("(%v).TruncateRestrictionFn() == nil, want value", 
sdofn.Name())
+       }
+       if !sdofn.IsWatermarkEstimating() {
+               t.Fatalf("(%v).IsWatermarkEstimating() = false, want true", 
dofn.Name())
+       }
+       if sdofn.CreateWatermarkEstimatorFn() == nil {
+               t.Errorf("(%v).CreateWatermarkEstimatorFn() == nil, want 
value", sdofn.Name())
+       }
+       if !sdofn.IsStatefulWatermarkEstimating() {
+               t.Fatalf("(%v).IsStatefulWatermarkEstimating() = false, want 
true", dofn.Name())
+       }
+       if sdofn.InitialWatermarkEstimatorStateFn() == nil {
+               t.Errorf("(%v).InitialWatermarkEstimatorStateFn() == nil, want 
value", sdofn.Name())
+       }
+       if sdofn.WatermarkEstimatorStateFn() == nil {
+               t.Errorf("(%v).WatermarkEstimatorStateFn() == nil, want value", 
sdofn.Name())
+       }
+       if got, want := sdofn.WatermarkEstimatorT(), 
reflect.TypeOf(&WatermarkEstimatorT{}); got != want {
+               t.Errorf("(%v).WatermarkEstimatorT() == %v, want %v", 
sdofn.Name(), got, want)
+       }
+       if got, want := sdofn.WatermarkEstimatorStateT(), reflectx.Int; got != 
want {
+               t.Errorf("(%v).WatermarkEstimatorStateT() == %v, want %v", 
sdofn.Name(), got, want)
+       }
+}
+
+func TestNewFn_CombineFn(t *testing.T) {
+       userFn := &GoodCombineFn{}
+       fn, err := NewFn(userFn)
+       if err != nil {
+               t.Fatalf("NewFn(%T) failed:\n%v", userFn, err)
+       }
+       cfn, err := AsCombineFn(fn)
+       if err != nil {
+               t.Fatalf("AsCombineFn(%v) failed:\n%v", fn.Name(), err)
+       }
+       // Check that we get expected values for all the methods.
+       if got, want := cfn.Name(), "GoodCombineFn"; !strings.HasSuffix(got, 
want) {
+               t.Errorf("(%v).Name() = %q, want suffix %q", cfn.Name(), got, 
want)
+       }
+       if cfn.SetupFn() == nil {
+               t.Errorf("(%v).SetupFn() == nil, want value", cfn.Name())
+       }
+       if cfn.CreateAccumulatorFn() == nil {
+               t.Errorf("(%v).CreateAccumulatorFn() == nil, want value", 
cfn.Name())
+       }
+       if cfn.AddInputFn() == nil {
+               t.Errorf("(%v).AddInputFn() == nil, want value", cfn.Name())
+       }
+       if cfn.MergeAccumulatorsFn() == nil {
+               t.Errorf("(%v).MergeAccumulatorsFn() == nil, want value", 
cfn.Name())
+       }
+       if cfn.ExtractOutputFn() == nil {
+               t.Errorf("(%v).ExtractOutputFn() == nil, want value", 
cfn.Name())
+       }
+       if cfn.CompactFn() == nil {
+               t.Errorf("(%v).CompactFn() == nil, want value", cfn.Name())
+       }
+       if cfn.TeardownFn() == nil {
+               t.Errorf("(%v).TeardownFn() == nil, want value", cfn.Name())
+       }
+}
+
 // Do not copy. The following types are for testing signatures only.
 // They are not working examples.
 // Keep all test functions Above this point.
@@ -798,6 +960,14 @@ func (fn *GoodSdfKv) TruncateRestriction(*RTrackerT, int, 
int) RestT {
        return RestT{}
 }
 
+type GoodIgnoreOtherExportedMethods struct {
+       *GoodSdf
+}
+
+func (fn *GoodIgnoreOtherExportedMethods) IgnoreOtherExportedMethods(int, 
RestT) int {
+       return 0
+}
+
 type WatermarkEstimatorT struct{}
 
 func (e *WatermarkEstimatorT) CurrentWatermark() time.Time {
@@ -1071,14 +1241,6 @@ func (fn *BadWatermarkEstimatingNonSdf) 
CreateWatermarkEstimator() *WatermarkEst
 
 // Examples of other type validation that needs to be done.
 
-type BadSdfRestSizeReturn struct {
-       *GoodSdf
-}
-
-func (fn *BadSdfRestSizeReturn) BadSdfRestSizeReturn(int, RestT) int {
-       return 0
-}
-
 type BadRTrackerT struct{} // Fails to implement RTracker interface.
 
 type BadSdfCreateTrackerReturn struct {
@@ -1266,6 +1428,8 @@ type MyAccum struct{}
 
 type GoodCombineFn struct{}
 
+func (fn *GoodCombineFn) Setup() {}
+
 func (fn *GoodCombineFn) MergeAccumulators(MyAccum, MyAccum) MyAccum {
        return MyAccum{}
 }
@@ -1282,6 +1446,12 @@ func (fn *GoodCombineFn) ExtractOutput(MyAccum) int64 {
        return 0
 }
 
+func (fn *GoodCombineFn) Compact(MyAccum) MyAccum {
+       return MyAccum{}
+}
+
+func (fn *GoodCombineFn) Teardown() {}
+
 type GoodWErrorCombineFn struct{}
 
 func (fn *GoodWErrorCombineFn) MergeAccumulators(int, int) (int, error) {
@@ -1314,6 +1484,14 @@ func (fn *GoodCombineFnUnexportedExtraMethod) 
unexportedExtraMethod(context.Cont
        return ""
 }
 
+type GoodCombineFnExtraExportedMethod struct {
+       *GoodCombineFn
+}
+
+func (fn *GoodCombineFnExtraExportedMethod) ExtraMethod(string) int {
+       return 0
+}
+
 // Examples of incorrect CombineFn signatures.
 // Embedding *GoodCombineFn avoids repetitive MergeAccumulators signatures 
when desired.
 // The immediately following examples are relating to accumulator mismatches.
@@ -1463,13 +1641,3 @@ type BadCombineFnInvalidExtractOutput3 struct {
 func (fn *BadCombineFnInvalidExtractOutput3) ExtractOutput(context.Context, 
MyAccum, int) int {
        return 0
 }
-
-// Other CombineFn Errors
-
-type BadCombineFnExtraExportedMethod struct {
-       *GoodCombineFn
-}
-
-func (fn *BadCombineFnExtraExportedMethod) ExtraMethod(string) int {
-       return 0
-}

Reply via email to