You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lo...@apache.org on 2022/06/01 17:52:18 UTC

[beam] branch master updated: [BEAM-14470] Use lifecycle method names directly. (#17790)

This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4bb3970f375 [BEAM-14470] Use lifecycle method names directly. (#17790)
4bb3970f375 is described below

commit 4bb3970f37548dc84ddad47aea7a1a4a1a3fda15
Author: Robert Burke <lo...@users.noreply.github.com>
AuthorDate: Wed Jun 1 10:52:11 2022 -0700

    [BEAM-14470] Use lifecycle method names directly. (#17790)
---
 sdks/go/pkg/beam/core/graph/fn.go      |  47 ++------
 sdks/go/pkg/beam/core/graph/fn_test.go | 208 +++++++++++++++++++++++++++++----
 2 files changed, 197 insertions(+), 58 deletions(-)

diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go
index 5eea802d5fa..6b3656c1c25 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -123,32 +123,23 @@ func NewFn(fn interface{}) (*Fn, error) {
 				methods[name] = f
 			}
 		}
-		// TODO(lostluck): Consider moving this into the reflectx package.
-		for i := 0; i < val.Type().NumMethod(); i++ {
-			m := val.Type().Method(i)
-			if m.PkgPath != "" {
-				continue // skip: unexported
-			}
-			if m.Name == "String" {
-				continue // skip: harmless
-			}
-			if _, ok := methods[m.Name]; ok {
+		for mName := range lifecycleMethods {
+			if _, ok := methods[mName]; ok {
 				continue // skip : already wrapped
 			}
+			m, ok := val.Type().MethodByName(mName)
+			if !ok {
+				continue // skip: doesn't exist
+			}
 
 			// CAVEAT(herohde) 5/22/2017: The type val.Type.Method.Type is not
 			// the same as val.Method.Type: the former has the explicit receiver.
 			// We'll use the receiver-less version.
-
-			// TODO(herohde) 5/22/2017: Alternatively, it looks like we could
-			// serialize each method, call them explicitly and avoid struct
-			// registration.
-
-			f, err := funcx.New(reflectx.MakeFunc(val.Method(i).Interface()))
+			f, err := funcx.New(reflectx.MakeFunc(val.Method(m.Index).Interface()))
 			if err != nil {
-				return nil, errors.Wrapf(err, "method %v invalid", m.Name)
+				return nil, errors.Wrapf(err, "method %v invalid", mName)
 			}
-			methods[m.Name] = f
+			methods[mName] = f
 		}
 		return &Fn{Recv: fn, methods: methods, annotations: annotations}, nil
 
@@ -450,9 +441,6 @@ func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) {
 	if fn.Fn != nil {
 		fn.methods[processElementName] = fn.Fn
 	}
-	if err := verifyValidNames("graph.AsDoFn", fn, doFnNames...); err != nil {
-		return nil, err
-	}
 
 	if _, ok := fn.methods[processElementName]; !ok {
 		err := errors.Errorf("failed to find %v method", processElementName)
@@ -1295,9 +1283,6 @@ func AsCombineFn(fn *Fn) (*CombineFn, error) {
 	if fn.Fn != nil {
 		fn.methods[mergeAccumulatorsName] = fn.Fn
 	}
-	if err := verifyValidNames(fnKind, fn, setupName, createAccumulatorName, addInputName, mergeAccumulatorsName, extractOutputName, compactName, teardownName); err != nil {
-		return nil, err
-	}
 
 	mergeFn, ok := fn.methods[mergeAccumulatorsName]
 	if !ok {
@@ -1356,20 +1341,6 @@ func validateSignature(fnKind, methodName string, fn *Fn, accumType reflect.Type
 	return nil
 }
 
-func verifyValidNames(fnKind string, fn *Fn, names ...string) error {
-	m := make(map[string]bool)
-	for _, name := range names {
-		m[name] = true
-	}
-
-	for key := range fn.methods {
-		if !m[key] {
-			return errors.Errorf("%s: unexpected exported method %v present on %v. Valid methods are: %v", fnKind, key, fn.Name(), names)
-		}
-	}
-	return nil
-}
-
 type verifyMethodError struct {
 	// Context for the error.
 	fnKind, methodName string
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go
index 0612f0ec4cb..a1702175f64 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -20,11 +20,13 @@ package graph
 import (
 	"context"
 	"reflect"
+	"strings"
 	"testing"
 	"time"
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
 )
 
 func TestNewDoFn(t *testing.T) {
@@ -161,6 +163,7 @@ func TestNewDoFnSdf(t *testing.T) {
 		}{
 			{dfn: &GoodSdf{}, main: MainSingle},
 			{dfn: &GoodSdfKv{}, main: MainKv},
+			{dfn: &GoodIgnoreOtherExportedMethods{}, main: MainSingle},
 		}
 
 		for _, test := range tests {
@@ -205,7 +208,6 @@ func TestNewDoFnSdf(t *testing.T) {
 			{dfn: &BadSdfRestTCreateTracker{}},
 			{dfn: &BadSdfRestTTruncateRestriction{}},
 			// Validate other types
-			{dfn: &BadSdfRestSizeReturn{}},
 			{dfn: &BadSdfCreateTrackerReturn{}},
 			{dfn: &BadSdfMismatchedRTracker{}},
 			{dfn: &BadSdfMissingRTracker{}},
@@ -321,6 +323,7 @@ func TestNewCombineFn(t *testing.T) {
 			{cfn: &GoodWErrorCombineFn{}},
 			{cfn: &GoodWContextCombineFn{}},
 			{cfn: &GoodCombineFnUnexportedExtraMethod{}},
+			{cfn: &GoodCombineFnExtraExportedMethod{}},
 		}
 
 		for _, test := range tests {
@@ -363,7 +366,6 @@ func TestNewCombineFn(t *testing.T) {
 			{cfn: &BadCombineFnInvalidExtractOutput1{}},
 			{cfn: &BadCombineFnInvalidExtractOutput2{}},
 			{cfn: &BadCombineFnInvalidExtractOutput3{}},
-			{cfn: &BadCombineFnExtraExportedMethod{}},
 		}
 		for _, test := range tests {
 			t.Run(reflect.TypeOf(test.cfn).String(), func(t *testing.T) {
@@ -378,6 +380,166 @@ func TestNewCombineFn(t *testing.T) {
 	})
 }
 
+func TestNewFn_DoFn(t *testing.T) {
+	// Validate wrap fallthrough: the wrapper below supplies only ProcessElement, so the other lifecycle methods must be picked up by NewFn's by-name lookup.
+	reflectx.RegisterStructWrapper(reflect.TypeOf((*GoodDoFn)(nil)).Elem(), func(fn interface{}) map[string]reflectx.Func {
+		gdf := fn.(*GoodDoFn)
+		return map[string]reflectx.Func{
+			processElementName: reflectx.MakeFunc1x1(func(v int) int {
+				return gdf.ProcessElement(v)
+			}),
+		}
+	})
+
+	userFn := &GoodDoFn{}
+	fn, err := NewFn(userFn)
+	if err != nil {
+		t.Fatalf("NewFn(%T) failed:\n%v", userFn, err) // Fatalf: fn is not usable after a failure.
+	}
+	dofn, err := AsDoFn(fn, MainSingle)
+	if err != nil {
+		t.Fatalf("AsDoFn(%v, MainSingle) failed:\n%v", fn.Name(), err) // Fatalf: dofn is dereferenced below.
+	}
+	// Check that we get expected values for all the methods; Errorf is used so every missing accessor is reported in one run.
+	if got, want := dofn.Name(), "GoodDoFn"; !strings.HasSuffix(got, want) {
+		t.Errorf("(%v).Name() = %q, want suffix %q", dofn.Name(), got, want)
+	}
+	if dofn.SetupFn() == nil {
+		t.Errorf("(%v).SetupFn() == nil, want value", dofn.Name())
+	}
+	if dofn.StartBundleFn() == nil {
+		t.Errorf("(%v).StartBundleFn() == nil, want value", dofn.Name())
+	}
+	if dofn.ProcessElementFn() == nil {
+		t.Errorf("(%v).ProcessElementFn() == nil, want value", dofn.Name())
+	}
+	if dofn.FinishBundleFn() == nil {
+		t.Errorf("(%v).FinishBundleFn() == nil, want value", dofn.Name())
+	}
+	if dofn.TeardownFn() == nil {
+		t.Errorf("(%v).TeardownFn() == nil, want value", dofn.Name())
+	}
+	if dofn.IsSplittable() {
+		t.Errorf("(%v).IsSplittable() = true, want false", dofn.Name())
+	}
+}
+
+func TestNewFn_SplittableDoFn(t *testing.T) {
+	userFn := &GoodStatefulWatermarkEstimating{}
+	fn, err := NewFn(userFn)
+	if err != nil {
+		t.Fatalf("NewFn(%T) failed:\n%v", userFn, err) // Fatalf: fn is not usable after a failure.
+	}
+	dofn, err := AsDoFn(fn, MainSingle)
+	if err != nil {
+		t.Fatalf("AsDoFn(%v, MainSingle) failed:\n%v", fn.Name(), err) // Fatalf: dofn is dereferenced below.
+	}
+	// Check that we get expected values for all the methods; the plain DoFn accessors come first.
+	if dofn.SetupFn() == nil {
+		t.Errorf("(%v).SetupFn() == nil, want value", dofn.Name())
+	}
+	if dofn.StartBundleFn() == nil {
+		t.Errorf("(%v).StartBundleFn() == nil, want value", dofn.Name())
+	}
+	if dofn.ProcessElementFn() == nil {
+		t.Errorf("(%v).ProcessElementFn() == nil, want value", dofn.Name())
+	}
+	if dofn.FinishBundleFn() == nil {
+		t.Errorf("(%v).FinishBundleFn() == nil, want value", dofn.Name())
+	}
+	if dofn.TeardownFn() == nil {
+		t.Errorf("(%v).TeardownFn() == nil, want value", dofn.Name())
+	}
+
+	if !dofn.IsSplittable() {
+		t.Fatalf("(%v).IsSplittable() = false, want true", dofn.Name())
+	}
+	sdofn := (*SplittableDoFn)(dofn) // The SDF-specific accessors checked below are called on this converted value.
+
+	if got, want := sdofn.Name(), "GoodStatefulWatermarkEstimating"; !strings.HasSuffix(got, want) {
+		t.Errorf("(%v).Name() = %q, want suffix %q", sdofn.Name(), got, want)
+	}
+	if sdofn.CreateInitialRestrictionFn() == nil {
+		t.Errorf("(%v).CreateInitialRestrictionFn() == nil, want value", sdofn.Name())
+	}
+	if sdofn.CreateTrackerFn() == nil {
+		t.Errorf("(%v).CreateTrackerFn() == nil, want value", sdofn.Name())
+	}
+	if sdofn.RestrictionSizeFn() == nil {
+		t.Errorf("(%v).RestrictionSizeFn() == nil, want value", sdofn.Name())
+	}
+	if got, want := sdofn.RestrictionT(), reflect.TypeOf(RestT{}); got != want {
+		t.Errorf("(%v).RestrictionT() == %v, want %v", sdofn.Name(), got, want)
+	}
+	if sdofn.SplitRestrictionFn() == nil {
+		t.Errorf("(%v).SplitRestrictionFn() == nil, want value", sdofn.Name())
+	}
+	if !sdofn.HasTruncateRestriction() {
+		t.Fatalf("(%v).HasTruncateRestriction() = false, want true", dofn.Name())
+	}
+	if sdofn.TruncateRestrictionFn() == nil {
+		t.Errorf("(%v).TruncateRestrictionFn() == nil, want value", sdofn.Name())
+	}
+	if !sdofn.IsWatermarkEstimating() {
+		t.Fatalf("(%v).IsWatermarkEstimating() = false, want true", dofn.Name())
+	}
+	if sdofn.CreateWatermarkEstimatorFn() == nil {
+		t.Errorf("(%v).CreateWatermarkEstimatorFn() == nil, want value", sdofn.Name())
+	}
+	if !sdofn.IsStatefulWatermarkEstimating() {
+		t.Fatalf("(%v).IsStatefulWatermarkEstimating() = false, want true", dofn.Name())
+	}
+	if sdofn.InitialWatermarkEstimatorStateFn() == nil {
+		t.Errorf("(%v).InitialWatermarkEstimatorStateFn() == nil, want value", sdofn.Name())
+	}
+	if sdofn.WatermarkEstimatorStateFn() == nil {
+		t.Errorf("(%v).WatermarkEstimatorStateFn() == nil, want value", sdofn.Name())
+	}
+	if got, want := sdofn.WatermarkEstimatorT(), reflect.TypeOf(&WatermarkEstimatorT{}); got != want {
+		t.Errorf("(%v).WatermarkEstimatorT() == %v, want %v", sdofn.Name(), got, want)
+	}
+	if got, want := sdofn.WatermarkEstimatorStateT(), reflectx.Int; got != want {
+		t.Errorf("(%v).WatermarkEstimatorStateT() == %v, want %v", sdofn.Name(), got, want)
+	}
+}
+
+func TestNewFn_CombineFn(t *testing.T) {
+	userFn := &GoodCombineFn{}
+	fn, err := NewFn(userFn)
+	if err != nil {
+		t.Fatalf("NewFn(%T) failed:\n%v", userFn, err) // Fatalf: fn is not usable after a failure.
+	}
+	cfn, err := AsCombineFn(fn)
+	if err != nil {
+		t.Fatalf("AsCombineFn(%v) failed:\n%v", fn.Name(), err) // Fatalf: cfn is dereferenced below.
+	}
+	// Check that we get expected values for all the methods; Errorf is used so every missing accessor is reported in one run.
+	if got, want := cfn.Name(), "GoodCombineFn"; !strings.HasSuffix(got, want) {
+		t.Errorf("(%v).Name() = %q, want suffix %q", cfn.Name(), got, want)
+	}
+	if cfn.SetupFn() == nil {
+		t.Errorf("(%v).SetupFn() == nil, want value", cfn.Name())
+	}
+	if cfn.CreateAccumulatorFn() == nil {
+		t.Errorf("(%v).CreateAccumulatorFn() == nil, want value", cfn.Name())
+	}
+	if cfn.AddInputFn() == nil {
+		t.Errorf("(%v).AddInputFn() == nil, want value", cfn.Name())
+	}
+	if cfn.MergeAccumulatorsFn() == nil {
+		t.Errorf("(%v).MergeAccumulatorsFn() == nil, want value", cfn.Name())
+	}
+	if cfn.ExtractOutputFn() == nil {
+		t.Errorf("(%v).ExtractOutputFn() == nil, want value", cfn.Name())
+	}
+	if cfn.CompactFn() == nil {
+		t.Errorf("(%v).CompactFn() == nil, want value", cfn.Name())
+	}
+	if cfn.TeardownFn() == nil {
+		t.Errorf("(%v).TeardownFn() == nil, want value", cfn.Name())
+	}
+}
+
 // Do not copy. The following types are for testing signatures only.
 // They are not working examples.
 // Keep all test functions Above this point.
@@ -798,6 +960,14 @@ func (fn *GoodSdfKv) TruncateRestriction(*RTrackerT, int, int) RestT {
 	return RestT{}
 }
 
+type GoodIgnoreOtherExportedMethods struct {
+	*GoodSdf // Embedded; this type adds only the extra exported method below.
+}
+
+func (fn *GoodIgnoreOtherExportedMethods) IgnoreOtherExportedMethods(int, RestT) int { // Extra exported method; the type appears among the accepted cases in TestNewDoFnSdf.
+	return 0
+}
+
 type WatermarkEstimatorT struct{}
 
 func (e *WatermarkEstimatorT) CurrentWatermark() time.Time {
@@ -1071,14 +1241,6 @@ func (fn *BadWatermarkEstimatingNonSdf) CreateWatermarkEstimator() *WatermarkEst
 
 // Examples of other type validation that needs to be done.
 
-type BadSdfRestSizeReturn struct {
-	*GoodSdf
-}
-
-func (fn *BadSdfRestSizeReturn) BadSdfRestSizeReturn(int, RestT) int {
-	return 0
-}
-
 type BadRTrackerT struct{} // Fails to implement RTracker interface.
 
 type BadSdfCreateTrackerReturn struct {
@@ -1266,6 +1428,8 @@ type MyAccum struct{}
 
 type GoodCombineFn struct{}
 
+func (fn *GoodCombineFn) Setup() {} // No-op; TestNewFn_CombineFn checks SetupFn() is non-nil for this type.
+
 func (fn *GoodCombineFn) MergeAccumulators(MyAccum, MyAccum) MyAccum {
 	return MyAccum{}
 }
@@ -1282,6 +1446,12 @@ func (fn *GoodCombineFn) ExtractOutput(MyAccum) int64 {
 	return 0
 }
 
+func (fn *GoodCombineFn) Compact(MyAccum) MyAccum { // Signature-only fixture; TestNewFn_CombineFn checks CompactFn() is non-nil.
+	return MyAccum{}
+}
+
+func (fn *GoodCombineFn) Teardown() {} // No-op; TestNewFn_CombineFn checks TeardownFn() is non-nil for this type.
+
 type GoodWErrorCombineFn struct{}
 
 func (fn *GoodWErrorCombineFn) MergeAccumulators(int, int) (int, error) {
@@ -1314,6 +1484,14 @@ func (fn *GoodCombineFnUnexportedExtraMethod) unexportedExtraMethod(context.Cont
 	return ""
 }
 
+type GoodCombineFnExtraExportedMethod struct {
+	*GoodCombineFn // Embedded; this type adds only the extra exported method below.
+}
+
+func (fn *GoodCombineFnExtraExportedMethod) ExtraMethod(string) int { // Extra exported method; the type appears among the accepted cases in TestNewCombineFn.
+	return 0
+}
+
 // Examples of incorrect CombineFn signatures.
 // Embedding *GoodCombineFn avoids repetitive MergeAccumulators signatures when desired.
 // The immediately following examples are relating to accumulator mismatches.
@@ -1463,13 +1641,3 @@ type BadCombineFnInvalidExtractOutput3 struct {
 func (fn *BadCombineFnInvalidExtractOutput3) ExtractOutput(context.Context, MyAccum, int) int {
 	return 0
 }
-
-// Other CombineFn Errors
-
-type BadCombineFnExtraExportedMethod struct {
-	*GoodCombineFn
-}
-
-func (fn *BadCombineFnExtraExportedMethod) ExtraMethod(string) int {
-	return 0
-}