You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@beam.apache.org by lo...@apache.org on 2022/06/01 17:52:18 UTC
[beam] branch master updated: [BEAM-14470] Use lifecycle method names directly. (#17790)
This is an automated email from the ASF dual-hosted git repository.
lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 4bb3970f375 [BEAM-14470] Use lifecycle method names directly. (#17790)
4bb3970f375 is described below
commit 4bb3970f37548dc84ddad47aea7a1a4a1a3fda15
Author: Robert Burke <lo...@users.noreply.github.com>
AuthorDate: Wed Jun 1 10:52:11 2022 -0700
[BEAM-14470] Use lifecycle method names directly. (#17790)
---
sdks/go/pkg/beam/core/graph/fn.go | 47 ++------
sdks/go/pkg/beam/core/graph/fn_test.go | 208 +++++++++++++++++++++++++++++----
2 files changed, 197 insertions(+), 58 deletions(-)
diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go
index 5eea802d5fa..6b3656c1c25 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -123,32 +123,23 @@ func NewFn(fn interface{}) (*Fn, error) {
methods[name] = f
}
}
- // TODO(lostluck): Consider moving this into the reflectx package.
- for i := 0; i < val.Type().NumMethod(); i++ {
- m := val.Type().Method(i)
- if m.PkgPath != "" {
- continue // skip: unexported
- }
- if m.Name == "String" {
- continue // skip: harmless
- }
- if _, ok := methods[m.Name]; ok {
+ for mName := range lifecycleMethods {
+ if _, ok := methods[mName]; ok {
continue // skip : already wrapped
}
+ m, ok := val.Type().MethodByName(mName)
+ if !ok {
+ continue // skip: doesn't exist
+ }
// CAVEAT(herohde) 5/22/2017: The type val.Type.Method.Type is not
// the same as val.Method.Type: the former has the explicit receiver.
// We'll use the receiver-less version.
-
- // TODO(herohde) 5/22/2017: Alternatively, it looks like we could
- // serialize each method, call them explicitly and avoid struct
- // registration.
-
- f, err := funcx.New(reflectx.MakeFunc(val.Method(i).Interface()))
+ f, err := funcx.New(reflectx.MakeFunc(val.Method(m.Index).Interface()))
if err != nil {
- return nil, errors.Wrapf(err, "method %v invalid", m.Name)
+ return nil, errors.Wrapf(err, "method %v invalid", mName)
}
- methods[m.Name] = f
+ methods[mName] = f
}
return &Fn{Recv: fn, methods: methods, annotations: annotations}, nil
@@ -450,9 +441,6 @@ func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) {
if fn.Fn != nil {
fn.methods[processElementName] = fn.Fn
}
- if err := verifyValidNames("graph.AsDoFn", fn, doFnNames...); err != nil {
- return nil, err
- }
if _, ok := fn.methods[processElementName]; !ok {
err := errors.Errorf("failed to find %v method", processElementName)
@@ -1295,9 +1283,6 @@ func AsCombineFn(fn *Fn) (*CombineFn, error) {
if fn.Fn != nil {
fn.methods[mergeAccumulatorsName] = fn.Fn
}
- if err := verifyValidNames(fnKind, fn, setupName, createAccumulatorName, addInputName, mergeAccumulatorsName, extractOutputName, compactName, teardownName); err != nil {
- return nil, err
- }
mergeFn, ok := fn.methods[mergeAccumulatorsName]
if !ok {
@@ -1356,20 +1341,6 @@ func validateSignature(fnKind, methodName string, fn *Fn, accumType reflect.Type
return nil
}
-func verifyValidNames(fnKind string, fn *Fn, names ...string) error {
- m := make(map[string]bool)
- for _, name := range names {
- m[name] = true
- }
-
- for key := range fn.methods {
- if !m[key] {
- return errors.Errorf("%s: unexpected exported method %v present on %v. Valid methods are: %v", fnKind, key, fn.Name(), names)
- }
- }
- return nil
-}
-
type verifyMethodError struct {
// Context for the error.
fnKind, methodName string
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go
index 0612f0ec4cb..a1702175f64 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -20,11 +20,13 @@ package graph
import (
"context"
"reflect"
+ "strings"
"testing"
"time"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
)
func TestNewDoFn(t *testing.T) {
@@ -161,6 +163,7 @@ func TestNewDoFnSdf(t *testing.T) {
}{
{dfn: &GoodSdf{}, main: MainSingle},
{dfn: &GoodSdfKv{}, main: MainKv},
+ {dfn: &GoodIgnoreOtherExportedMethods{}, main: MainSingle},
}
for _, test := range tests {
@@ -205,7 +208,6 @@ func TestNewDoFnSdf(t *testing.T) {
{dfn: &BadSdfRestTCreateTracker{}},
{dfn: &BadSdfRestTTruncateRestriction{}},
// Validate other types
- {dfn: &BadSdfRestSizeReturn{}},
{dfn: &BadSdfCreateTrackerReturn{}},
{dfn: &BadSdfMismatchedRTracker{}},
{dfn: &BadSdfMissingRTracker{}},
@@ -321,6 +323,7 @@ func TestNewCombineFn(t *testing.T) {
{cfn: &GoodWErrorCombineFn{}},
{cfn: &GoodWContextCombineFn{}},
{cfn: &GoodCombineFnUnexportedExtraMethod{}},
+ {cfn: &GoodCombineFnExtraExportedMethod{}},
}
for _, test := range tests {
@@ -363,7 +366,6 @@ func TestNewCombineFn(t *testing.T) {
{cfn: &BadCombineFnInvalidExtractOutput1{}},
{cfn: &BadCombineFnInvalidExtractOutput2{}},
{cfn: &BadCombineFnInvalidExtractOutput3{}},
- {cfn: &BadCombineFnExtraExportedMethod{}},
}
for _, test := range tests {
t.Run(reflect.TypeOf(test.cfn).String(), func(t *testing.T) {
@@ -378,6 +380,166 @@ func TestNewCombineFn(t *testing.T) {
})
}
+func TestNewFn_DoFn(t *testing.T) {
+ // Validate wrap fallthrough
+ reflectx.RegisterStructWrapper(reflect.TypeOf((*GoodDoFn)(nil)).Elem(), func(fn interface{}) map[string]reflectx.Func {
+ gdf := fn.(*GoodDoFn)
+ return map[string]reflectx.Func{
+ processElementName: reflectx.MakeFunc1x1(func(v int) int {
+ return gdf.ProcessElement(v)
+ }),
+ }
+ })
+
+ userFn := &GoodDoFn{}
+ fn, err := NewFn(userFn)
+ if err != nil {
+ t.Errorf("NewFn(%T) failed:\n%v", userFn, err)
+ }
+ dofn, err := AsDoFn(fn, MainSingle)
+ if err != nil {
+ t.Errorf("AsDoFn(%v, MainSingle) failed:\n%v", fn.Name(), err)
+ }
+ // Check that we get expected values for all the methods.
+ if got, want := dofn.Name(), "GoodDoFn"; !strings.HasSuffix(got, want) {
+ t.Errorf("(%v).Name() = %q, want suffix %q", dofn.Name(), got, want)
+ }
+ if dofn.SetupFn() == nil {
+ t.Errorf("(%v).SetupFn() == nil, want value", dofn.Name())
+ }
+ if dofn.StartBundleFn() == nil {
+ t.Errorf("(%v).StartBundleFn() == nil, want value", dofn.Name())
+ }
+ if dofn.ProcessElementFn() == nil {
+ t.Errorf("(%v).ProcessElementFn() == nil, want value", dofn.Name())
+ }
+ if dofn.FinishBundleFn() == nil {
+ t.Errorf("(%v).FinishBundleFn() == nil, want value", dofn.Name())
+ }
+ if dofn.TeardownFn() == nil {
+ t.Errorf("(%v).TeardownFn() == nil, want value", dofn.Name())
+ }
+ if dofn.IsSplittable() {
+ t.Errorf("(%v).IsSplittable() = true, want false", dofn.Name())
+ }
+}
+
+func TestNewFn_SplittableDoFn(t *testing.T) {
+ userFn := &GoodStatefulWatermarkEstimating{}
+ fn, err := NewFn(userFn)
+ if err != nil {
+ t.Errorf("NewFn(%T) failed:\n%v", userFn, err)
+ }
+ dofn, err := AsDoFn(fn, MainSingle)
+ if err != nil {
+ t.Errorf("AsDoFn(%v, MainKv) failed:\n%v", fn.Name(), err)
+ }
+ // Check that we get expected values for all the methods.
+ if dofn.SetupFn() == nil {
+ t.Errorf("(%v).SetupFn() == nil, want value", dofn.Name())
+ }
+ if dofn.StartBundleFn() == nil {
+ t.Errorf("(%v).StartBundleFn() == nil, want value", dofn.Name())
+ }
+ if dofn.ProcessElementFn() == nil {
+ t.Errorf("(%v).ProcessElementFn() == nil, want value", dofn.Name())
+ }
+ if dofn.FinishBundleFn() == nil {
+ t.Errorf("(%v).FinishBundleFn() == nil, want value", dofn.Name())
+ }
+ if dofn.TeardownFn() == nil {
+ t.Errorf("(%v).TeardownFn() == nil, want value", dofn.Name())
+ }
+
+ if !dofn.IsSplittable() {
+ t.Fatalf("(%v).IsSplittable() = false, want true", dofn.Name())
+ }
+ sdofn := (*SplittableDoFn)(dofn)
+
+ if got, want := sdofn.Name(), "GoodStatefulWatermarkEstimating"; !strings.HasSuffix(got, want) {
+ t.Errorf("(%v).Name() = %q, want suffix %q", sdofn.Name(), got, want)
+ }
+ if sdofn.CreateInitialRestrictionFn() == nil {
+ t.Errorf("(%v).CreateInitialRestrictionFn() == nil, want value", sdofn.Name())
+ }
+ if sdofn.CreateTrackerFn() == nil {
+ t.Errorf("(%v).CreateTrackerFn() == nil, want value", sdofn.Name())
+ }
+ if sdofn.RestrictionSizeFn() == nil {
+ t.Errorf("(%v).RestrictionSizeFn() == nil, want value", sdofn.Name())
+ }
+ if got, want := sdofn.RestrictionT(), reflect.TypeOf(RestT{}); got != want {
+ t.Errorf("(%v).RestrictionT() == %v, want %v", sdofn.Name(), got, want)
+ }
+ if sdofn.SplitRestrictionFn() == nil {
+ t.Errorf("(%v).SplitRestrictionFn() == nil, want value", sdofn.Name())
+ }
+ if !sdofn.HasTruncateRestriction() {
+ t.Fatalf("(%v).HasTruncateRestriction() = false, want true", dofn.Name())
+ }
+ if sdofn.TruncateRestrictionFn() == nil {
+ t.Errorf("(%v).TruncateRestrictionFn() == nil, want value", sdofn.Name())
+ }
+ if !sdofn.IsWatermarkEstimating() {
+ t.Fatalf("(%v).IsWatermarkEstimating() = false, want true", dofn.Name())
+ }
+ if sdofn.CreateWatermarkEstimatorFn() == nil {
+ t.Errorf("(%v).CreateWatermarkEstimatorFn() == nil, want value", sdofn.Name())
+ }
+ if !sdofn.IsStatefulWatermarkEstimating() {
+ t.Fatalf("(%v).IsStatefulWatermarkEstimating() = false, want true", dofn.Name())
+ }
+ if sdofn.InitialWatermarkEstimatorStateFn() == nil {
+ t.Errorf("(%v).InitialWatermarkEstimatorStateFn() == nil, want value", sdofn.Name())
+ }
+ if sdofn.WatermarkEstimatorStateFn() == nil {
+ t.Errorf("(%v).WatermarkEstimatorStateFn() == nil, want value", sdofn.Name())
+ }
+ if got, want := sdofn.WatermarkEstimatorT(), reflect.TypeOf(&WatermarkEstimatorT{}); got != want {
+ t.Errorf("(%v).WatermarkEstimatorT() == %v, want %v", sdofn.Name(), got, want)
+ }
+ if got, want := sdofn.WatermarkEstimatorStateT(), reflectx.Int; got != want {
+ t.Errorf("(%v).WatermarkEstimatorT() == %v, want %v", sdofn.Name(), got, want)
+ }
+}
+
+func TestNewFn_CombineFn(t *testing.T) {
+ userFn := &GoodCombineFn{}
+ fn, err := NewFn(userFn)
+ if err != nil {
+ t.Errorf("NewFn(%T) failed:\n%v", userFn, err)
+ }
+ cfn, err := AsCombineFn(fn)
+ if err != nil {
+ t.Errorf("AsCombineFn(%v) failed:\n%v", fn.Name(), err)
+ }
+ // Check that we get expected values for all the methods.
+ if got, want := cfn.Name(), "GoodCombineFn"; !strings.HasSuffix(got, want) {
+ t.Errorf("(%v).Name() = %q, want suffix %q", cfn.Name(), got, want)
+ }
+ if cfn.SetupFn() == nil {
+ t.Errorf("(%v).SetupFn() == nil, want value", cfn.Name())
+ }
+ if cfn.CreateAccumulatorFn() == nil {
+ t.Errorf("(%v).CreateAccumulatorFn() == nil, want value", cfn.Name())
+ }
+ if cfn.AddInputFn() == nil {
+ t.Errorf("(%v).AddInputFn() == nil, want value", cfn.Name())
+ }
+ if cfn.MergeAccumulatorsFn() == nil {
+ t.Errorf("(%v).MergeAccumulatorsFn() == nil, want value", cfn.Name())
+ }
+ if cfn.ExtractOutputFn() == nil {
+ t.Errorf("(%v).ExtractOutputFn() == nil, want value", cfn.Name())
+ }
+ if cfn.CompactFn() == nil {
+ t.Errorf("(%v).CompactFn() == nil, want value", cfn.Name())
+ }
+ if cfn.TeardownFn() == nil {
+ t.Errorf("(%v).TeardownFn() == nil, want value", cfn.Name())
+ }
+}
+
// Do not copy. The following types are for testing signatures only.
// They are not working examples.
// Keep all test functions Above this point.
@@ -798,6 +960,14 @@ func (fn *GoodSdfKv) TruncateRestriction(*RTrackerT, int, int) RestT {
return RestT{}
}
+type GoodIgnoreOtherExportedMethods struct {
+ *GoodSdf
+}
+
+func (fn *GoodIgnoreOtherExportedMethods) IgnoreOtherExportedMethods(int, RestT) int {
+ return 0
+}
+
type WatermarkEstimatorT struct{}
func (e *WatermarkEstimatorT) CurrentWatermark() time.Time {
@@ -1071,14 +1241,6 @@ func (fn *BadWatermarkEstimatingNonSdf) CreateWatermarkEstimator() *WatermarkEst
// Examples of other type validation that needs to be done.
-type BadSdfRestSizeReturn struct {
- *GoodSdf
-}
-
-func (fn *BadSdfRestSizeReturn) BadSdfRestSizeReturn(int, RestT) int {
- return 0
-}
-
type BadRTrackerT struct{} // Fails to implement RTracker interface.
type BadSdfCreateTrackerReturn struct {
@@ -1266,6 +1428,8 @@ type MyAccum struct{}
type GoodCombineFn struct{}
+func (fn *GoodCombineFn) Setup() {}
+
func (fn *GoodCombineFn) MergeAccumulators(MyAccum, MyAccum) MyAccum {
return MyAccum{}
}
@@ -1282,6 +1446,12 @@ func (fn *GoodCombineFn) ExtractOutput(MyAccum) int64 {
return 0
}
+func (fn *GoodCombineFn) Compact(MyAccum) MyAccum {
+ return MyAccum{}
+}
+
+func (fn *GoodCombineFn) Teardown() {}
+
type GoodWErrorCombineFn struct{}
func (fn *GoodWErrorCombineFn) MergeAccumulators(int, int) (int, error) {
@@ -1314,6 +1484,14 @@ func (fn *GoodCombineFnUnexportedExtraMethod) unexportedExtraMethod(context.Cont
return ""
}
+type GoodCombineFnExtraExportedMethod struct {
+ *GoodCombineFn
+}
+
+func (fn *GoodCombineFnExtraExportedMethod) ExtraMethod(string) int {
+ return 0
+}
+
// Examples of incorrect CombineFn signatures.
// Embedding *GoodCombineFn avoids repetitive MergeAccumulators signatures when desired.
// The immediately following examples are relating to accumulator mismatches.
@@ -1463,13 +1641,3 @@ type BadCombineFnInvalidExtractOutput3 struct {
func (fn *BadCombineFnInvalidExtractOutput3) ExtractOutput(context.Context, MyAccum, int) int {
return 0
}
-
-// Other CombineFn Errors
-
-type BadCombineFnExtraExportedMethod struct {
- *GoodCombineFn
-}
-
-func (fn *BadCombineFnExtraExportedMethod) ExtraMethod(string) int {
- return 0
-}