You are viewing a plain-text version of this content. The hyperlink to its canonical version was not preserved in this rendering.
Posted to commits@beam.apache.org by lo...@apache.org on 2023/02/16 20:02:22 UTC
[beam] branch master updated: [Go SDK]: Allow SDF methods to have context param and error return value (#25437)
This is an automated email from the ASF dual-hosted git repository.
lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 29ea6e0eb8a [Go SDK]: Allow SDF methods to have context param and error return value (#25437)
29ea6e0eb8a is described below
commit 29ea6e0eb8a2f2cf571d1a27799d125ca051008b
Author: Johanna Öjeling <51...@users.noreply.github.com>
AuthorDate: Thu Feb 16 21:02:08 2023 +0100
[Go SDK]: Allow SDF methods to have context param and error return value (#25437)
* Allow context param and error return value in SDF validation
* Use context param and error return value in SDF method invocation
* Run go fmt
* Clean up error messages from "fn reflect.methodValueCall"
* Validate return value count in a more correct way
---
sdks/go/pkg/beam/core/graph/fn.go | 108 ++++---
sdks/go/pkg/beam/core/graph/fn_test.go | 87 ++++++
sdks/go/pkg/beam/core/runtime/exec/datasource.go | 8 +-
.../pkg/beam/core/runtime/exec/datasource_test.go | 12 +-
sdks/go/pkg/beam/core/runtime/exec/sdf.go | 134 +++++---
sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go | 339 +++++++++------------
.../beam/core/runtime/exec/sdf_invokers_arity.go | 336 ++++++++++++++++++++
.../beam/core/runtime/exec/sdf_invokers_arity.tmpl | 246 +++++++++++++++
.../beam/core/runtime/exec/sdf_invokers_test.go | 250 +++++++++++++--
sdks/go/pkg/beam/core/runtime/exec/sdf_test.go | 17 +-
.../beam/runners/prism/internal/config/config.go | 2 +-
.../beam/runners/prism/internal/urns/urns_test.go | 2 +-
12 files changed, 1227 insertions(+), 314 deletions(-)
diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go
index 907af4045b5..25b846370fb 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -866,14 +866,15 @@ func validateSdfSignatures(fn *Fn, numMainIn mainInputs) error {
// CreateInitialRestriction.
if numMainIn == MainUnknown {
initialRestFn := fn.methods[createInitialRestrictionName]
- paramNum := len(initialRestFn.Param)
+ paramNum := len(initialRestFn.Params(funcx.FnValue))
+
switch paramNum {
case int(MainSingle), int(MainKv):
num = paramNum
default: // Can't infer because method has invalid # of main inputs.
- err := errors.Errorf("invalid number of params in method %v. got: %v, want: %v or %v",
+ err := errors.Errorf("invalid number of main input params in method %v. got: %v, want: %v or %v",
createInitialRestrictionName, paramNum, int(MainSingle), int(MainKv))
- return errors.SetTopLevelMsgf(err, "Invalid number of parameters in method %v. "+
+ return errors.SetTopLevelMsgf(err, "Invalid number of main input parameters in method %v. "+
"Got: %v, Want: %v or %v. Check that the signature conforms to the expected signature for %v, "+
"and that elements in SDF method parameters match elements in %v.",
createInitialRestrictionName, paramNum, int(MainSingle), int(MainKv), createInitialRestrictionName, processElementName)
@@ -894,7 +895,7 @@ func validateSdfSignatures(fn *Fn, numMainIn mainInputs) error {
// in each SDF method in the given Fn, and returns an error if a method has an
// invalid/unexpected number.
func validateSdfSigNumbers(fn *Fn, num int) error {
- paramNums := map[string]int{
+ reqParamNums := map[string]int{
createInitialRestrictionName: num,
splitRestrictionName: num + 1,
restrictionSizeName: num + 1,
@@ -904,32 +905,52 @@ func validateSdfSigNumbers(fn *Fn, num int) error {
optionalSdfs := map[string]bool{
truncateRestrictionName: true,
}
- returnNum := 1 // TODO(BEAM-3301): Enable optional error params in SDF methods.
+ reqReturnNum := 1
for _, name := range sdfNames {
method, ok := fn.methods[name]
if !ok && optionalSdfs[name] {
continue
}
- if len(method.Param) != paramNums[name] {
- err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v",
- name, len(method.Param), paramNums[name])
+
+ reqParamNum := reqParamNums[name]
+ if !sdfHasValidParamNum(method.Param, reqParamNum) {
+ err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v or optionally %v "+
+ "if first param is of type context.Context", name, len(method.Param), reqParamNum, reqParamNum+1)
return errors.SetTopLevelMsgf(err, "Unexpected number of parameters in method %v. "+
- "Got: %v, Want: %v. Check that the signature conforms to the expected signature for %v, "+
- "and that elements in SDF method parameters match elements in %v.",
- name, len(method.Param), paramNums[name], name, processElementName)
+ "Got: %v, Want: %v or optionally %v if first param is of type context.Context. "+
+ "Check that the signature conforms to the expected signature for %v, and that elements in SDF method "+
+ "parameters match elements in %v.", name, len(method.Param), reqParamNum, reqParamNum+1,
+ name, processElementName)
}
- if len(method.Ret) != returnNum {
- err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v",
- name, len(method.Ret), returnNum)
+ if !sdfHasValidReturnNum(method.Ret, reqReturnNum) {
+ err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v or optionally %v "+
+ "if last value is of type error", name, len(method.Ret), reqReturnNum, reqReturnNum+1)
return errors.SetTopLevelMsgf(err, "Unexpected number of return values in method %v. "+
- "Got: %v, Want: %v. Check that the signature conforms to the expected signature for %v.",
- name, len(method.Ret), returnNum, name)
+ "Got: %v, Want: %v or optionally %v if last value is of type error. "+
+ "Check that the signature conforms to the expected signature for %v.",
+ name, len(method.Ret), reqReturnNum, reqReturnNum+1, name)
}
}
return nil
}
+func sdfHasValidParamNum(params []funcx.FnParam, requiredNum int) bool {
+ if len(params) == requiredNum {
+ return true
+ }
+
+ return len(params) == requiredNum+1 && params[0].Kind == funcx.FnContext
+}
+
+func sdfHasValidReturnNum(returns []funcx.ReturnParam, requiredNum int) bool {
+ if len(returns) == requiredNum {
+ return true
+ }
+
+ return len(returns) == requiredNum+1 && returns[len(returns)-1].Kind == funcx.RetError
+}
+
// validateSdfSigTypes validates the types of the parameters and return values
// in each SDF method in the given Fn, and returns an error if a method has an
// invalid/mismatched type. Assumes that the number of parameters and return
@@ -940,22 +961,25 @@ func validateSdfSigTypes(fn *Fn, num int) error {
for _, name := range requiredSdfNames {
method := fn.methods[name]
+ startIdx := sdfRequiredParamStartIndex(method)
+
switch name {
case createInitialRestrictionName:
- if err := validateSdfElementT(fn, createInitialRestrictionName, method, num, 0); err != nil {
+ if err := validateSdfElementT(fn, createInitialRestrictionName, method, num, startIdx); err != nil {
return err
}
case splitRestrictionName:
- if err := validateSdfElementT(fn, splitRestrictionName, method, num, 0); err != nil {
+ if err := validateSdfElementT(fn, splitRestrictionName, method, num, startIdx); err != nil {
return err
}
- if method.Param[num].T != restrictionT {
+ idx := num + startIdx
+ if method.Param[idx].T != restrictionT {
err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
- splitRestrictionName, num, method.Param[num].T, restrictionT)
+ splitRestrictionName, idx, method.Param[idx].T, restrictionT)
return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+
"parameter at index %v. Got: %v, Want: %v (from method %v). "+
"Ensure that all restrictions in an SDF are the same type.",
- splitRestrictionName, num, method.Param[num].T, restrictionT, createInitialRestrictionName)
+ splitRestrictionName, idx, method.Param[idx].T, restrictionT, createInitialRestrictionName)
}
if method.Ret[0].T.Kind() != reflect.Slice ||
method.Ret[0].T.Elem() != restrictionT {
@@ -967,16 +991,17 @@ func validateSdfSigTypes(fn *Fn, num int) error {
splitRestrictionName, 0, method.Ret[0].T, reflect.SliceOf(restrictionT), createInitialRestrictionName, splitRestrictionName)
}
case restrictionSizeName:
- if err := validateSdfElementT(fn, restrictionSizeName, method, num, 0); err != nil {
+ if err := validateSdfElementT(fn, restrictionSizeName, method, num, startIdx); err != nil {
return err
}
- if method.Param[num].T != restrictionT {
+ idx := num + startIdx
+ if method.Param[idx].T != restrictionT {
err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
- restrictionSizeName, num, method.Param[num].T, restrictionT)
+ restrictionSizeName, idx, method.Param[idx].T, restrictionT)
return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+
"parameter at index %v. Got: %v, Want: %v (from method %v). "+
"Ensure that all restrictions in an SDF are the same type.",
- restrictionSizeName, num, method.Param[num].T, restrictionT, createInitialRestrictionName)
+ restrictionSizeName, idx, method.Param[idx].T, restrictionT, createInitialRestrictionName)
}
if method.Ret[0].T != reflectx.Float64 {
err := errors.Errorf("invalid output type in method %v, return %v. got: %v, want: %v",
@@ -986,13 +1011,13 @@ func validateSdfSigTypes(fn *Fn, num int) error {
restrictionSizeName, 0, method.Ret[0].T, reflectx.Float64)
}
case createTrackerName:
- if method.Param[0].T != restrictionT {
+ if method.Param[startIdx].T != restrictionT {
err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
- createTrackerName, 0, method.Param[0].T, restrictionT)
+ createTrackerName, startIdx, method.Param[startIdx].T, restrictionT)
return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+
"parameter at index %v. Got: %v, Want: %v (from method %v). "+
"Ensure that all restrictions in an SDF are the same type.",
- createTrackerName, 0, method.Param[0].T, restrictionT, createInitialRestrictionName)
+ createTrackerName, startIdx, method.Param[startIdx].T, restrictionT, createInitialRestrictionName)
}
if !method.Ret[0].T.Implements(rTrackerT) {
err := errors.Errorf("invalid output type in method %v, return %v: %v does not implement sdf.RTracker",
@@ -1020,15 +1045,18 @@ func validateSdfSigTypes(fn *Fn, num int) error {
if !ok {
continue
}
+
+ startIdx := sdfRequiredParamStartIndex(method)
+
switch name {
case truncateRestrictionName:
- if method.Param[0].T != rTrackerImplT {
+ if method.Param[startIdx].T != rTrackerImplT {
err := errors.Errorf("mismatched restriction tracker type in method %v, param %v. got: %v, want: %v",
- truncateRestrictionName, 0, method.Param[0].T, rTrackerImplT)
+ truncateRestrictionName, startIdx, method.Param[startIdx].T, rTrackerImplT)
return errors.SetTopLevelMsgf(err, "Mismatched restriction tracker type in method %v, "+
"parameter at index %v. Got: %v, Want: %v (from method %v). "+
"Ensure that restriction tracker is the first parameter.",
- truncateRestrictionName, 0, method.Param[0].T, rTrackerImplT, createTrackerName)
+ truncateRestrictionName, startIdx, method.Param[startIdx].T, rTrackerImplT, createTrackerName)
}
if method.Ret[0].T != restrictionT {
err := errors.Errorf("invalid output type in method %v, return %v. got: %v, want: %v",
@@ -1052,6 +1080,14 @@ func validateSdfSigTypes(fn *Fn, num int) error {
return nil
}
+func sdfRequiredParamStartIndex(method *funcx.Fn) int {
+ if ctxIndex, ok := method.Context(); ok {
+ return ctxIndex + 1
+ }
+
+ return 0
+}
+
// validateSdfElementT validates that element types in an SDF method are
// consistent with the ProcessElement method. This method assumes that the
// first 'num' parameters starting with startIndex are the elements.
@@ -1062,13 +1098,14 @@ func validateSdfElementT(fn *Fn, name string, method *funcx.Fn, num int, startIn
pos, _, _ := processFn.Inputs()
for i := 0; i < num; i++ {
- if method.Param[i+startIndex].T != processFn.Param[pos+i].T {
+ idx := i + startIndex
+ if method.Param[idx].T != processFn.Param[pos+i].T {
err := errors.Errorf("mismatched element type in method %v, param %v. got: %v, want: %v",
- name, i, method.Param[i].T, processFn.Param[pos+i].T)
+ name, idx, method.Param[idx].T, processFn.Param[pos+i].T)
return errors.SetTopLevelMsgf(err, "Mismatched element type in method %v, "+
"parameter at index %v. Got: %v, Want: %v (from method %v). "+
"Ensure that element parameters in SDF methods have consistent types with element parameters in %v.",
- name, i, method.Param[i].T, processFn.Param[pos+i].T, processElementName, processElementName)
+ name, idx, method.Param[idx].T, processFn.Param[pos+i].T, processElementName, processElementName)
}
}
return nil
@@ -1178,7 +1215,8 @@ func validateStatefulWatermarkSig(fn *Fn, numMainIn int) error {
// CreateInitialRestriction.
if numMainIn == int(MainUnknown) {
initialRestFn := fn.methods[createInitialRestrictionName]
- paramNum := len(initialRestFn.Param)
+ paramNum := len(initialRestFn.Params(funcx.FnValue))
+
switch paramNum {
case int(MainSingle), int(MainKv):
numMainIn = paramNum
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go
index d2f88a8a5ce..cf44761d4f3 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -190,6 +190,9 @@ func TestNewDoFnSdf(t *testing.T) {
}{
{dfn: &GoodSdf{}, main: MainSingle},
{dfn: &GoodSdfKv{}, main: MainKv},
+ {dfn: &GoodSdfWContext{}, main: MainSingle},
+ {dfn: &GoodSdfKvWContext{}, main: MainKv},
+ {dfn: &GoodSdfWErr{}, main: MainSingle},
{dfn: &GoodIgnoreOtherExportedMethods{}, main: MainSingle},
}
@@ -987,6 +990,90 @@ func (fn *GoodSdfKv) TruncateRestriction(*RTrackerT, int, int) RestT {
return RestT{}
}
+type GoodSdfWContext struct {
+ *GoodDoFn
+}
+
+func (fn *GoodSdfWContext) CreateInitialRestriction(context.Context, int) RestT {
+ return RestT{}
+}
+
+func (fn *GoodSdfWContext) SplitRestriction(context.Context, int, RestT) []RestT {
+ return []RestT{}
+}
+
+func (fn *GoodSdfWContext) RestrictionSize(context.Context, int, RestT) float64 {
+ return 0
+}
+
+func (fn *GoodSdfWContext) CreateTracker(context.Context, RestT) *RTrackerT {
+ return &RTrackerT{}
+}
+
+func (fn *GoodSdfWContext) ProcessElement(context.Context, *RTrackerT, int) (int, sdf.ProcessContinuation) {
+ return 0, sdf.StopProcessing()
+}
+
+func (fn *GoodSdfWContext) TruncateRestriction(context.Context, *RTrackerT, int) RestT {
+ return RestT{}
+}
+
+type GoodSdfKvWContext struct {
+ *GoodDoFnKv
+}
+
+func (fn *GoodSdfKvWContext) CreateInitialRestriction(context.Context, int, int) RestT {
+ return RestT{}
+}
+
+func (fn *GoodSdfKvWContext) SplitRestriction(context.Context, int, int, RestT) []RestT {
+ return []RestT{}
+}
+
+func (fn *GoodSdfKvWContext) RestrictionSize(context.Context, int, int, RestT) float64 {
+ return 0
+}
+
+func (fn *GoodSdfKvWContext) CreateTracker(context.Context, RestT) *RTrackerT {
+ return &RTrackerT{}
+}
+
+func (fn *GoodSdfKvWContext) ProcessElement(context.Context, *RTrackerT, int, int) (int, sdf.ProcessContinuation) {
+ return 0, sdf.StopProcessing()
+}
+
+func (fn *GoodSdfKvWContext) TruncateRestriction(context.Context, *RTrackerT, int, int) RestT {
+ return RestT{}
+}
+
+type GoodSdfWErr struct {
+ *GoodDoFn
+}
+
+func (fn *GoodSdfWErr) CreateInitialRestriction(int) (RestT, error) {
+ return RestT{}, nil
+}
+
+func (fn *GoodSdfWErr) SplitRestriction(int, RestT) ([]RestT, error) {
+ return []RestT{}, nil
+}
+
+func (fn *GoodSdfWErr) RestrictionSize(int, RestT) (float64, error) {
+ return 0, nil
+}
+
+func (fn *GoodSdfWErr) CreateTracker(RestT) (*RTrackerT, error) {
+ return &RTrackerT{}, nil
+}
+
+func (fn *GoodSdfWErr) ProcessElement(*RTrackerT, int) (int, sdf.ProcessContinuation, error) {
+ return 0, sdf.StopProcessing(), nil
+}
+
+func (fn *GoodSdfWErr) TruncateRestriction(*RTrackerT, int) (RestT, error) {
+ return RestT{}, nil
+}
+
type GoodIgnoreOtherExportedMethods struct {
*GoodSdf
}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go
index 9c4de0564c8..a6347fc8d0e 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go
@@ -196,7 +196,7 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) {
// Check if there's a continuation and return residuals
// Needs to be done immeadiately after processing to not lose the element.
if c := n.getProcessContinuation(); c != nil {
- cp, err := n.checkpointThis(c)
+ cp, err := n.checkpointThis(ctx, c)
if err != nil {
// Errors during checkpointing should fail a bundle.
return nil, err
@@ -422,7 +422,7 @@ type Checkpoint struct {
// splittable or has not returned a resuming continuation, the function returns an empty
// SplitResult, a negative resumption time, and a false boolean to indicate that no split
// occurred.
-func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, error) {
+func (n *DataSource) checkpointThis(ctx context.Context, pc sdf.ProcessContinuation) (*Checkpoint, error) {
n.mu.Lock()
defer n.mu.Unlock()
@@ -435,7 +435,7 @@ func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, er
ow := su.GetOutputWatermark()
// Checkpointing is functionally a split at fraction 0.0
- rs, err := su.Checkpoint()
+ rs, err := su.Checkpoint(ctx)
if err != nil {
return nil, err
}
@@ -530,7 +530,7 @@ func (n *DataSource) Split(ctx context.Context, splits []int64, frac float64, bu
// Get the output watermark before splitting to avoid accidentally overestimating
ow := su.GetOutputWatermark()
// Otherwise, perform a sub-element split.
- ps, rs, err := su.Split(fr)
+ ps, rs, err := su.Split(ctx, fr)
if err != nil {
return SplitResult{}, err
}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
index 64a37739b24..2da3284f016 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
@@ -631,7 +631,7 @@ type TestSplittableUnit struct {
// Split checks the input fraction for correctness, but otherwise always returns
// a successful split. The split elements are just copies of the original.
-func (n *TestSplittableUnit) Split(f float64) ([]*FullValue, []*FullValue, error) {
+func (n *TestSplittableUnit) Split(_ context.Context, f float64) ([]*FullValue, []*FullValue, error) {
if f > 1.0 || f < 0.0 {
return nil, nil, errors.Errorf("Error")
}
@@ -639,8 +639,8 @@ func (n *TestSplittableUnit) Split(f float64) ([]*FullValue, []*FullValue, error
}
// Checkpoint routes through the Split() function to satisfy the interface.
-func (n *TestSplittableUnit) Checkpoint() ([]*FullValue, error) {
- _, r, err := n.Split(0.0)
+func (n *TestSplittableUnit) Checkpoint(ctx context.Context) ([]*FullValue, error) {
+ _, r, err := n.Split(ctx, 0.0)
return r, err
}
@@ -876,13 +876,13 @@ func TestSplitHelper(t *testing.T) {
func TestCheckpointing(t *testing.T) {
t.Run("nil", func(t *testing.T) {
- cps, err := (&DataSource{}).checkpointThis(nil)
+ cps, err := (&DataSource{}).checkpointThis(context.Background(), nil)
if err != nil {
t.Fatalf("checkpointThis() = %v, %v", cps, err)
}
})
t.Run("Stop", func(t *testing.T) {
- cps, err := (&DataSource{}).checkpointThis(sdf.StopProcessing())
+ cps, err := (&DataSource{}).checkpointThis(context.Background(), sdf.StopProcessing())
if err != nil {
t.Fatalf("checkpointThis() = %v, %v", cps, err)
}
@@ -899,7 +899,7 @@ func TestCheckpointing(t *testing.T) {
},
},
}
- cp, err := root.checkpointThis(sdf.ResumeProcessingIn(time.Second * 13))
+ cp, err := root.checkpointThis(context.Background(), sdf.ResumeProcessingIn(time.Second*13))
if err != nil {
t.Fatalf("checkpointThis() = %v, %v, want nil", cp, err)
}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf.go b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
index e22496eae6e..1dd3e35dc4d 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
@@ -90,7 +90,11 @@ func (n *PairWithRestriction) StartBundle(ctx context.Context, id string, data D
// Timestamps
// }
func (n *PairWithRestriction) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error {
- rest := n.inv.Invoke(elm)
+ rest, err := n.inv.Invoke(ctx, elm)
+ if err != nil {
+ return err
+ }
+
output := FullValue{Elm: elm, Elm2: &FullValue{Elm: rest, Elm2: n.iwesInv.Invoke(rest, elm)}, Timestamp: elm.Timestamp, Windows: elm.Windows}
return n.Out.ProcessElement(ctx, &output, values...)
@@ -195,10 +199,17 @@ func (n *SplitAndSizeRestrictions) ProcessElement(ctx context.Context, elm *Full
// the element may not be wrapped in a *FullValue
mainElm := convertIfNeeded(elm.Elm, &FullValue{})
- splitRests := n.splitInv.Invoke(mainElm, rest)
+ splitRests, err := n.splitInv.Invoke(ctx, mainElm, rest)
+ if err != nil {
+ return err
+ }
for _, splitRest := range splitRests {
- size := n.sizeInv.Invoke(mainElm, splitRest)
+ size, err := n.sizeInv.Invoke(ctx, mainElm, splitRest)
+ if err != nil {
+ return err
+ }
+
if size < 0 {
err := errors.Errorf("size returned expected to be non-negative but received %v.", size)
return errors.WithContextf(err, "%v", n)
@@ -325,13 +336,25 @@ func (n *TruncateSizedRestriction) ProcessElement(ctx context.Context, elm *Full
inp = e
}
rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
- rt := n.ctInv.Invoke(rest)
- newRest := n.truncateInv.Invoke(rt, mainElm)
+
+ rt, err := n.ctInv.Invoke(ctx, rest)
+ if err != nil {
+ return err
+ }
+
+ newRest, err := n.truncateInv.Invoke(ctx, rt, mainElm)
+ if err != nil {
+ return err
+ }
if newRest == nil {
// do not propagate discarded restrictions.
return nil
}
- size := n.sizeInv.Invoke(mainElm, newRest)
+
+ size, err := n.sizeInv.Invoke(ctx, mainElm, newRest)
+ if err != nil {
+ return err
+ }
output := &FullValue{}
output.Timestamp = elm.Timestamp
@@ -476,7 +499,7 @@ func (n *ProcessSizedElementsAndRestrictions) StartBundle(ctx context.Context, i
// and processes each element using the underlying ParDo and adding the
// restriction tracker to the normal invocation. Sizing information is present
// but currently ignored. Output is forwarded to the underlying ParDo's outputs.
-func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context, elm *FullValue, values ...ReStream) error {
+func (n *ProcessSizedElementsAndRestrictions) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error {
if n.PDo.status != Active {
err := errors.Errorf("invalid status %v, want Active", n.PDo.status)
return errors.WithContextf(err, "%v", n)
@@ -520,7 +543,12 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context,
// If windows don't need to be exploded (i.e. aren't observed), treat
// all windows as one as an optimization.
rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
- rt := n.ctInv.Invoke(rest)
+
+ rt, err := n.ctInv.Invoke(ctx, rest)
+ if err != nil {
+ return err
+ }
+
mainIn.RTracker = rt
n.numW = 1 // Even if there's more than one window, treat them as one.
@@ -542,7 +570,12 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context,
for i := 0; i < n.numW; i++ {
rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
- rt := n.ctInv.Invoke(rest)
+
+ rt, err := n.ctInv.Invoke(ctx, rest)
+ if err != nil {
+ return err
+ }
+
key := &mainIn.Key
w := elm.Windows[i]
wElm := FullValue{Elm: key.Elm, Elm2: key.Elm2, Timestamp: key.Timestamp, Windows: []typex.Window{w}}
@@ -552,7 +585,7 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context,
n.elm = elm
n.SU <- n
// TODO(BEAM-11104): Remove placeholder for ProcessContinuation return.
- _, err := n.PDo.processSingleWindow(&MainInput{Key: wElm, Values: mainIn.Values, RTracker: rt})
+ _, err = n.PDo.processSingleWindow(&MainInput{Key: wElm, Values: mainIn.Values, RTracker: rt})
if err != nil {
<-n.SU
return n.PDo.fail(err)
@@ -596,13 +629,13 @@ type SplittableUnit interface {
//
// More than one primary/residual can happen if the split result cannot be
// fully represented in just one.
- Split(fraction float64) (primaries, residuals []*FullValue, err error)
+ Split(ctx context.Context, fraction float64) (primaries, residuals []*FullValue, err error)
// Checkpoint performs a split at fraction 0.0 of an element that has stopped
// processing and has work that needs to be resumed later. This function will
// check that the produced primary restriction from the split represents
// completed work to avoid data loss and will error if work remains.
- Checkpoint() (residuals []*FullValue, err error)
+ Checkpoint(ctx context.Context) (residuals []*FullValue, err error)
// GetProgress returns the fraction of progress the current element has
// made in processing. (ex. 0.0 means no progress, and 1.0 means fully
@@ -631,7 +664,7 @@ type SplittableUnit interface {
// windows need to be taken into account. For implementation details on when
// each case occurs and the implementation details, see the documentation for
// the singleWindowSplit and multiWindowSplit methods.
-func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) Split(ctx context.Context, f float64) ([]*FullValue, []*FullValue, error) {
// Get the watermark state immediately so that we don't overestimate our current watermark.
var pWeState any
var rWeState any
@@ -658,7 +691,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []
// Split behavior differs depending on whether this is a window-observing
// DoFn or not.
if len(n.elm.Windows) > 1 {
- p, r, err := n.multiWindowSplit(f, pWeState, rWeState)
+ p, r, err := n.multiWindowSplit(ctx, f, pWeState, rWeState)
if err != nil {
return nil, nil, addContext(err)
}
@@ -666,7 +699,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []
}
// Not window-observing, or window-observing but only one window.
- p, r, err := n.singleWindowSplit(f, pWeState, rWeState)
+ p, r, err := n.singleWindowSplit(ctx, f, pWeState, rWeState)
if err != nil {
return nil, nil, addContext(err)
}
@@ -677,11 +710,11 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []
// later by the runner. This is done iff the underlying Splittable DoFn returns a resuming
// ProcessContinuation. If the split occurs and the primary restriction is marked as done
// my the RTracker, the Checkpoint fails as this is a potential data-loss case.
-func (n *ProcessSizedElementsAndRestrictions) Checkpoint() ([]*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) Checkpoint(ctx context.Context) ([]*FullValue, error) {
addContext := func(err error) error {
return errors.WithContext(err, "Attempting checkpoint in ProcessSizedElementsAndRestrictions")
}
- _, r, err := n.Split(0.0)
+ _, r, err := n.Split(ctx, 0.0)
if err != nil {
return nil, addContext(err)
@@ -699,7 +732,7 @@ func (n *ProcessSizedElementsAndRestrictions) Checkpoint() ([]*FullValue, error)
// behavior is identical). A single restriction split will occur and all windows
// present in the unsplit element will be present in both the resulting primary
// and residual.
-func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, pWeState, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(ctx context.Context, f float64, pWeState, rWeState any) ([]*FullValue, []*FullValue, error) {
if n.rt.IsDone() { // Not an error, but not splittable.
return []*FullValue{}, []*FullValue{}, nil
}
@@ -714,14 +747,14 @@ func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, pWeSt
var primaryResult []*FullValue
if p != nil {
- pfv, err := n.newSplitResult(p, n.elm.Windows, pWeState)
+ pfv, err := n.newSplitResult(ctx, p, n.elm.Windows, pWeState)
if err != nil {
return nil, nil, err
}
primaryResult = append(primaryResult, pfv)
}
- rfv, err := n.newSplitResult(r, n.elm.Windows, rWeState)
+ rfv, err := n.newSplitResult(ctx, r, n.elm.Windows, rWeState)
if err != nil {
return nil, nil, err
}
@@ -752,7 +785,7 @@ func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, pWeSt
//
// This method also updates the current number of windows (n.numW) so that
// windows in the residual will no longer be processed.
-func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(ctx context.Context, f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
// Get the split point in window range, to see what window it falls in.
done, rem := n.rt.GetProgress()
cwp := done / (done + rem) // Progress in current window.
@@ -765,25 +798,25 @@ func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64, pWeSta
if n.rt.IsDone() {
// Current RTracker is done so we can't split within the window, so
// split at window boundary instead.
- return n.windowBoundarySplit(n.currW+1, pWeState, rWeState)
+ return n.windowBoundarySplit(ctx, n.currW+1, pWeState, rWeState)
}
// Get the fraction of remaining work in the current window to split at.
cwsp := wsp - float64(n.currW) // Split point in current window.
rf := (cwsp - cwp) / (1 - cwp) // Fraction of work in RTracker to split at.
- return n.currentWindowSplit(rf, pWeState, rWeState)
+ return n.currentWindowSplit(ctx, rf, pWeState, rWeState)
} else {
// Split at nearest window boundary to split point.
wb := math.Round(wsp)
- return n.windowBoundarySplit(int(wb), pWeState, rWeState)
+ return n.windowBoundarySplit(ctx, int(wb), pWeState, rWeState)
}
}
// currentWindowSplit performs an appropriate split at the given fraction of
// remaining work in the current window. Also updates numW to stop after the
// current window.
-func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(ctx context.Context, f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
p, r, err := n.rt.TrySplit(f)
if err != nil {
return nil, nil, err
@@ -791,18 +824,18 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeS
if r == nil {
// If r is nil then the split failed/returned an empty residual, but
// we can still split at a window boundary.
- return n.windowBoundarySplit(n.currW+1, pWeState, rWeState)
+ return n.windowBoundarySplit(ctx, n.currW+1, pWeState, rWeState)
}
// Split of currently processing restriction in a single window.
ps := make([]*FullValue, 1)
- newP, err := n.newSplitResult(p, n.elm.Windows[n.currW:n.currW+1], pWeState)
+ newP, err := n.newSplitResult(ctx, p, n.elm.Windows[n.currW:n.currW+1], pWeState)
if err != nil {
return nil, nil, err
}
ps[0] = newP
rs := make([]*FullValue, 1)
- newR, err := n.newSplitResult(r, n.elm.Windows[n.currW:n.currW+1], rWeState)
+ newR, err := n.newSplitResult(ctx, r, n.elm.Windows[n.currW:n.currW+1], rWeState)
if err != nil {
return nil, nil, err
}
@@ -810,14 +843,14 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeS
// Window boundary split surrounding the split restriction above.
full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
if 0 < n.currW {
- newP, err := n.newSplitResult(full, n.elm.Windows[0:n.currW], pWeState)
+ newP, err := n.newSplitResult(ctx, full, n.elm.Windows[0:n.currW], pWeState)
if err != nil {
return nil, nil, err
}
ps = append(ps, newP)
}
if n.currW+1 < n.numW {
- newR, err := n.newSplitResult(full, n.elm.Windows[n.currW+1:n.numW], rWeState)
+ newR, err := n.newSplitResult(ctx, full, n.elm.Windows[n.currW+1:n.numW], rWeState)
if err != nil {
return nil, nil, err
}
@@ -830,17 +863,17 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeS
// windowBoundarySplit performs an appropriate split at a window boundary. The
// split point taken should be the index of the first window in the residual.
// Also updates numW to stop at the split point.
-func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(ctx context.Context, splitPt int, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
// If this is at the boundary of the last window, split is a no-op.
if splitPt == n.numW {
return []*FullValue{}, []*FullValue{}, nil
}
full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
- pFv, err := n.newSplitResult(full, n.elm.Windows[0:splitPt], pWeState)
+ pFv, err := n.newSplitResult(ctx, full, n.elm.Windows[0:splitPt], pWeState)
if err != nil {
return nil, nil, err
}
- rFv, err := n.newSplitResult(full, n.elm.Windows[splitPt:n.numW], rWeState)
+ rFv, err := n.newSplitResult(ctx, full, n.elm.Windows[splitPt:n.numW], rWeState)
if err != nil {
return nil, nil, err
}
@@ -852,18 +885,27 @@ func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int, p
// element restriction pair based on the currently processing element, but with
// a modified restriction and windows. Intended for creating primaries and
// residuals to return as split results.
-func (n *ProcessSizedElementsAndRestrictions) newSplitResult(rest any, w []typex.Window, weState any) (*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) newSplitResult(ctx context.Context, rest any, w []typex.Window, weState any) (*FullValue, error) {
var size float64
+ var err error
elm := n.elm.Elm.(*FullValue).Elm
if fv, ok := elm.(*FullValue); ok {
- size = n.sizeInv.Invoke(fv, rest)
+ size, err = n.sizeInv.Invoke(ctx, fv, rest)
+ if err != nil {
+ return nil, err
+ }
+
if size < 0 {
err := errors.Errorf("size returned expected to be non-negative but received %v.", size)
return nil, errors.WithContextf(err, "%v", n)
}
} else {
fv := &FullValue{Elm: elm}
- size = n.sizeInv.Invoke(fv, rest)
+ size, err = n.sizeInv.Invoke(ctx, fv, rest)
+ if err != nil {
+ return nil, err
+ }
+
if size < 0 {
err := errors.Errorf("size returned expected to be non-negative but received %v.", size)
return nil, errors.WithContextf(err, "%v", n)
@@ -973,21 +1015,33 @@ func (n *SdfFallback) StartBundle(ctx context.Context, id string, data DataConte
// restrictions, and then creating restriction trackers and processing each
// restriction with the underlying ParDo. This executor skips the sizing step
// because sizing information is unnecessary for unexpanded SDFs.
-func (n *SdfFallback) ProcessElement(_ context.Context, elm *FullValue, values ...ReStream) error {
+func (n *SdfFallback) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error {
if n.PDo.status != Active {
err := errors.Errorf("invalid status %v, want Active", n.PDo.status)
return errors.WithContextf(err, "%v", n)
}
- rest := n.initRestInv.Invoke(elm)
- splitRests := n.splitInv.Invoke(elm, rest)
+ rest, err := n.initRestInv.Invoke(ctx, elm)
+ if err != nil {
+ return err
+ }
+
+ splitRests, err := n.splitInv.Invoke(ctx, elm, rest)
+ if err != nil {
+ return err
+ }
+
if len(splitRests) == 0 {
err := errors.Errorf("initial splitting returned 0 restrictions.")
return errors.WithContextf(err, "%v", n)
}
for _, splitRest := range splitRests {
- rt := n.trackerInv.Invoke(splitRest)
+ rt, err := n.trackerInv.Invoke(ctx, splitRest)
+ if err != nil {
+ return err
+ }
+
mainIn := &MainInput{
Key: *elm,
Values: values,
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
index 5d58f198a64..2dd894ed08c 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
@@ -16,6 +16,7 @@
package exec
import (
+ "context"
"reflect"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx"
@@ -24,6 +25,9 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
)
+//go:generate specialize --input=sdf_invokers_arity.tmpl
+//go:generate gofmt -w sdf_invokers_arity.go
+
// This file contains invokers for SDF methods. These invokers are based off
// exec.invoker which is used for regular DoFns. Since exec.invoker is
// specialized for DoFns it cannot be used for SDF methods. Instead, these
@@ -39,9 +43,10 @@ import (
// cirInvoker is an invoker for CreateInitialRestriction.
type cirInvoker struct {
- fn *funcx.Fn
- args []any // Cache to avoid allocating new slices per-element.
- call func(elms *FullValue) (rest any)
+ fn *funcx.Fn
+ args []any // Cache to avoid allocating new slices per-element.
+ ctxIdx int // Index of the context.Context param in args; -1 if the fn takes none.
+ call func() (rest any, err error) // Invokes fn with the current contents of args.
}
func newCreateInitialRestrictionInvoker(fn *funcx.Fn) (*cirInvoker, error) {
@@ -49,51 +54,34 @@ func newCreateInitialRestrictionInvoker(fn *funcx.Fn) (*cirInvoker, error) {
fn: fn,
args: make([]any, len(fn.Param)),
}
+
+ var ok bool
+ if n.ctxIdx, ok = fn.Context(); !ok {
+ n.ctxIdx = -1
+ }
+
if err := n.initCallFn(); err != nil {
return nil, errors.WithContext(err, "sdf CreateInitialRestriction invoker")
}
return n, nil
}
-func (n *cirInvoker) initCallFn() error {
- // Expects a signature of the form:
- // (key?, value) restriction
- // TODO(BEAM-9643): Link to full documentation.
- switch fnT := n.fn.Fn.(type) {
- case reflectx.Func1x1:
- n.call = func(elms *FullValue) any {
- return fnT.Call1x1(elms.Elm)
- }
- case reflectx.Func2x1:
- n.call = func(elms *FullValue) any {
- return fnT.Call2x1(elms.Elm, elms.Elm2)
- }
- default:
- switch len(n.fn.Param) {
- case 1:
- n.call = func(elms *FullValue) any {
- n.args[0] = elms.Elm
- return n.fn.Fn.Call(n.args)[0]
- }
- case 2:
- n.call = func(elms *FullValue) any {
- n.args[0] = elms.Elm
- n.args[1] = elms.Elm2
- return n.fn.Fn.Call(n.args)[0]
- }
- default:
- return errors.Errorf("CreateInitialRestriction fn %v has unexpected number of parameters: %v",
- n.fn.Fn.Name(), len(n.fn.Param))
- }
+// Invoke calls CreateInitialRestriction with ctx (if declared) and elms, passing
+// Elm2 only when the fn's parameter count includes a key; returns its result.
+func (n *cirInvoker) Invoke(ctx context.Context, elms *FullValue) (rest any, err error) {
+ if n.ctxIdx >= 0 {
+ n.args[n.ctxIdx] = ctx
}
- return nil
-}
+ i := n.ctxIdx + 1
+ n.args[i] = elms.Elm
-// Invoke calls CreateInitialRestriction with the given FullValue as the element
-// and returns the resulting restriction.
-func (n *cirInvoker) Invoke(elms *FullValue) (rest any) {
- return n.call(elms)
+ if len(n.args)-i == 2 {
+ i++
+ n.args[i] = elms.Elm2
+ }
+
+ return n.call()
}
// Reset zeroes argument entries in the cached slice to allow values to be
@@ -106,9 +94,10 @@ func (n *cirInvoker) Reset() {
// srInvoker is an invoker for SplitRestriction.
type srInvoker struct {
- fn *funcx.Fn
- args []any // Cache to avoid allocating new slices per-element.
- call func(elms *FullValue, rest any) (splits any)
+ fn *funcx.Fn
+ args []any // Cache to avoid allocating new slices per-element.
+ ctxIdx int // Index of the context.Context param in args; -1 if the fn takes none.
+ call func() (splits any, err error) // Invokes fn with the current contents of args.
}
func newSplitRestrictionInvoker(fn *funcx.Fn) (*srInvoker, error) {
@@ -116,52 +105,40 @@ func newSplitRestrictionInvoker(fn *funcx.Fn) (*srInvoker, error) {
fn: fn,
args: make([]any, len(fn.Param)),
}
+
+ var ok bool
+ if n.ctxIdx, ok = fn.Context(); !ok {
+ n.ctxIdx = -1
+ }
+
if err := n.initCallFn(); err != nil {
return nil, errors.WithContext(err, "sdf SplitRestriction invoker")
}
return n, nil
}
-func (n *srInvoker) initCallFn() error {
- // Expects a signature of the form:
- // (key?, value, restriction) []restriction
- // TODO(BEAM-9643): Link to full documentation.
- switch fnT := n.fn.Fn.(type) {
- case reflectx.Func2x1:
- n.call = func(elms *FullValue, rest any) any {
- return fnT.Call2x1(elms.Elm, rest)
- }
- case reflectx.Func3x1:
- n.call = func(elms *FullValue, rest any) any {
- return fnT.Call3x1(elms.Elm, elms.Elm2, rest)
- }
- default:
- switch len(n.fn.Param) {
- case 2:
- n.call = func(elms *FullValue, rest any) any {
- n.args[0] = elms.Elm
- n.args[1] = rest
- return n.fn.Fn.Call(n.args)[0]
- }
- case 3:
- n.call = func(elms *FullValue, rest any) any {
- n.args[0] = elms.Elm
- n.args[1] = elms.Elm2
- n.args[2] = rest
- return n.fn.Fn.Call(n.args)[0]
- }
- default:
- return errors.Errorf("SplitRestriction fn %v has unexpected number of parameters: %v",
- n.fn.Fn.Name(), len(n.fn.Param))
- }
- }
- return nil
-}
-
// Invoke calls SplitRestriction given a FullValue containing an element and
// the associated restriction, and returns a slice of split restrictions.
-func (n *srInvoker) Invoke(elms *FullValue, rest any) (splits []any) {
- ret := n.call(elms, rest)
+func (n *srInvoker) Invoke(ctx context.Context, elms *FullValue, rest any) (splits []any, err error) {
+ if n.ctxIdx >= 0 {
+ n.args[n.ctxIdx] = ctx
+ }
+
+ // Pass Elm2 iff the fn takes (key, value, restriction). Deciding from the
+ // signature (not Elm2 != nil) keeps a nil value in its own slot.
+ i := n.ctxIdx + 1
+ n.args[i] = elms.Elm
+ if len(n.args)-i == 3 {
+ i++
+ n.args[i] = elms.Elm2
+ }
+ i++
+ n.args[i] = rest
+
+ ret, err := n.call()
+ if err != nil {
+ return nil, err
+ }
// Return value is an any, but we need to convert it to a []any.
val := reflect.ValueOf(ret)
@@ -169,7 +146,7 @@ func (n *srInvoker) Invoke(elms *FullValue, rest any) (splits []any) {
for i := 0; i < val.Len(); i++ {
s = append(s, val.Index(i).Interface())
}
- return s
+ return s, nil
}
// Reset zeroes argument entries in the cached slice to allow values to be
@@ -182,9 +159,10 @@ func (n *srInvoker) Reset() {
// rsInvoker is an invoker for RestrictionSize.
type rsInvoker struct {
- fn *funcx.Fn
- args []any // Cache to avoid allocating new slices per-element.
- call func(elms *FullValue, rest any) (size float64)
+ fn *funcx.Fn
+ args []any // Cache to avoid allocating new slices per-element.
+ ctxIdx int // Index of the context.Context param in args; -1 if the fn takes none.
+ call func() (size float64, err error) // Invokes fn with the current contents of args.
}
func newRestrictionSizeInvoker(fn *funcx.Fn) (*rsInvoker, error) {
@@ -192,52 +170,37 @@ func newRestrictionSizeInvoker(fn *funcx.Fn) (*rsInvoker, error) {
fn: fn,
args: make([]any, len(fn.Param)),
}
+
+ var ok bool
+ if n.ctxIdx, ok = fn.Context(); !ok {
+ n.ctxIdx = -1
+ }
+
if err := n.initCallFn(); err != nil {
return nil, errors.WithContext(err, "sdf RestrictionSize invoker")
}
return n, nil
}
-func (n *rsInvoker) initCallFn() error {
- // Expects a signature of the form:
- // (key?, value, restriction) float64
- // TODO(BEAM-9643): Link to full documentation.
- switch fnT := n.fn.Fn.(type) {
- case reflectx.Func2x1:
- n.call = func(elms *FullValue, rest any) float64 {
- return fnT.Call2x1(elms.Elm, rest).(float64)
- }
- case reflectx.Func3x1:
- n.call = func(elms *FullValue, rest any) float64 {
- return fnT.Call3x1(elms.Elm, elms.Elm2, rest).(float64)
- }
- default:
- switch len(n.fn.Param) {
- case 2:
- n.call = func(elms *FullValue, rest any) float64 {
- n.args[0] = elms.Elm
- n.args[1] = rest
- return n.fn.Fn.Call(n.args)[0].(float64)
- }
- case 3:
- n.call = func(elms *FullValue, rest any) float64 {
- n.args[0] = elms.Elm
- n.args[1] = elms.Elm2
- n.args[2] = rest
- return n.fn.Fn.Call(n.args)[0].(float64)
- }
- default:
- return errors.Errorf("RestrictionSize fn %v has unexpected number of parameters: %v",
- n.fn.Fn.Name(), len(n.fn.Param))
- }
- }
- return nil
-}
-
// Invoke calls RestrictionSize given a FullValue containing an element and
// the associated restriction, and returns a size.
-func (n *rsInvoker) Invoke(elms *FullValue, rest any) (size float64) {
- return n.call(elms, rest)
+func (n *rsInvoker) Invoke(ctx context.Context, elms *FullValue, rest any) (size float64, err error) {
+ if n.ctxIdx >= 0 {
+ n.args[n.ctxIdx] = ctx
+ }
+
+ // Pass Elm2 iff the fn takes (key, value, restriction). Deciding from the
+ // signature (not Elm2 != nil) keeps a nil value in its own slot.
+ i := n.ctxIdx + 1
+ n.args[i] = elms.Elm
+ if len(n.args)-i == 3 {
+ i++
+ n.args[i] = elms.Elm2
+ }
+ i++
+ n.args[i] = rest
+
+ return n.call()
}
// Reset zeroes argument entries in the cached slice to allow values to be
@@ -250,9 +213,10 @@ func (n *rsInvoker) Reset() {
// ctInvoker is an invoker for CreateTracker.
type ctInvoker struct {
- fn *funcx.Fn
- args []any // Cache to avoid allocating new slices per-element.
- call func(rest any) sdf.RTracker
+ fn *funcx.Fn
+ args []any // Cache to avoid allocating new slices per-element.
+ ctxIdx int // Index of the context.Context param in args; -1 if the fn takes none.
+ call func() (rt sdf.RTracker, err error) // Invokes fn with the current contents of args.
}
func newCreateTrackerInvoker(fn *funcx.Fn) (*ctInvoker, error) {
@@ -260,37 +224,27 @@ func newCreateTrackerInvoker(fn *funcx.Fn) (*ctInvoker, error) {
fn: fn,
args: make([]any, len(fn.Param)),
}
+
+ var ok bool
+ if n.ctxIdx, ok = fn.Context(); !ok {
+ n.ctxIdx = -1
+ }
+
if err := n.initCallFn(); err != nil {
return nil, errors.WithContext(err, "sdf CreateTracker invoker")
}
return n, nil
}
-func (n *ctInvoker) initCallFn() error {
- // Expects a signature of the form:
- // (restriction) sdf.RTracker
- // TODO(BEAM-9643): Link to full documentation.
- switch fnT := n.fn.Fn.(type) {
- case reflectx.Func1x1:
- n.call = func(rest any) sdf.RTracker {
- return fnT.Call1x1(rest).(sdf.RTracker)
- }
- default:
- if len(n.fn.Param) != 1 {
- return errors.Errorf("CreateTracker fn %v has unexpected number of parameters: %v",
- n.fn.Fn.Name(), len(n.fn.Param))
- }
- n.call = func(rest any) sdf.RTracker {
- n.args[0] = rest
- return n.fn.Fn.Call(n.args)[0].(sdf.RTracker)
- }
+// Invoke calls CreateTracker with rest (and ctx, if the fn declares a context param) and returns the resulting sdf.RTracker plus any error from the fn.
+func (n *ctInvoker) Invoke(ctx context.Context, rest any) (sdf.RTracker, error) {
+ if n.ctxIdx >= 0 {
+ n.args[n.ctxIdx] = ctx
}
- return nil
-}
-// Invoke calls CreateTracker given a restriction and returns an sdf.RTracker.
-func (n *ctInvoker) Invoke(rest any) sdf.RTracker {
- return n.call(rest)
+ n.args[n.ctxIdx+1] = rest // rest directly follows ctx; index 0 when ctxIdx is -1.
+
+ return n.call()
}
// Reset zeroes argument entries in the cached slice to allow values to be
@@ -303,9 +257,10 @@ func (n *ctInvoker) Reset() {
// trInvoker is an invoker for TruncateRestriction.
type trInvoker struct {
- fn *funcx.Fn
- args []any
- call func(rest any, elms *FullValue) (pair any)
+ fn *funcx.Fn
+ args []any // Argument slots reused across Invoke calls.
+ ctxIdx int // Index of the context.Context param in args, or -1 if the fn takes none (unused when fn is nil).
+ call func() (rest any, err error) // Invokes the truncation with the current contents of args.
}
func defaultTruncateRestriction(restTracker any) (newRest any) {
@@ -320,6 +275,12 @@ func newTruncateRestrictionInvoker(fn *funcx.Fn) (*trInvoker, error) {
fn: fn,
args: make([]any, len(fn.Param)),
}
+
+ var ok bool
+ if n.ctxIdx, ok = fn.Context(); !ok {
+ n.ctxIdx = -1
+ }
+
if err := n.initCallFn(); err != nil {
return nil, errors.WithContext(err, "sdf TruncateRestriction invoker")
}
@@ -327,53 +288,39 @@ func newTruncateRestrictionInvoker(fn *funcx.Fn) (*trInvoker, error) {
}
func newDefaultTruncateRestrictionInvoker() (*trInvoker, error) {
- n := &trInvoker{}
- n.call = func(rest any, elms *FullValue) any {
- return defaultTruncateRestriction(rest)
+ n := &trInvoker{
+ args: make([]any, 1), // Single slot: the restriction tracker handed to Invoke.
}
- return n, nil
-}
-
-func (n *trInvoker) initCallFn() error {
- // Expects a signature of the form:
- // (key?, value, restriction) []restriction
- // TODO(BEAM-9643): Link to full documentation.
- switch fnT := n.fn.Fn.(type) {
- case reflectx.Func2x1:
- n.call = func(rest any, elms *FullValue) any {
- return fnT.Call2x1(rest, elms.Elm)
- }
- case reflectx.Func3x1:
- n.call = func(rest any, elms *FullValue) any {
- return fnT.Call3x1(rest, elms.Elm, elms.Elm2)
- }
- default:
- switch len(n.fn.Param) {
- case 2:
- n.call = func(rest any, elms *FullValue) any {
- n.args[0] = rest
- n.args[1] = elms.Elm
- return n.fn.Fn.Call(n.args)[0]
- }
- case 3:
- n.call = func(rest any, elms *FullValue) any {
- n.args[0] = rest
- n.args[1] = elms.Elm
- n.args[2] = elms.Elm2
- return n.fn.Fn.Call(n.args)[0]
- }
- default:
- return errors.Errorf("TruncateRestriction fn %v has unexpected number of parameters: %v",
- n.fn.Fn.Name(), len(n.fn.Param))
- }
+ n.call = func() (any, error) {
+ return defaultTruncateRestriction(n.args[0]), nil
}
- return nil
+ return n, nil
}
// Invoke calls TruncateRestriction given a FullValue containing an element and
// the associated restriction tracker, and returns a truncated restriction.
-func (n *trInvoker) Invoke(rt any, elms *FullValue) (rest any) {
- return n.call(rt, elms)
+func (n *trInvoker) Invoke(ctx context.Context, rt any, elms *FullValue) (rest any, err error) {
+ if n.fn == nil {
+ n.args[0] = rt // Default invoker: only the tracker is needed.
+ return n.call()
+ }
+
+ if n.ctxIdx >= 0 {
+ n.args[n.ctxIdx] = ctx
+ }
+
+ i := n.ctxIdx + 1
+ n.args[i] = rt
+ i++
+ n.args[i] = elms.Elm
+
+ // Pass Elm2 iff the fn takes (tracker, key, value). Deciding from the
+ // signature (not Elm2 != nil) keeps a nil value in its own slot.
+ if len(n.args)-i == 2 {
+ i++
+ n.args[i] = elms.Elm2
+ }
+ return n.call()
}
// Reset zeroes argument entries in the cached slice to allow values to be
@@ -589,3 +536,11 @@ func (n *wesInvoker) Reset() {
n.args[i] = nil
}
}
+
+// asError converts an untyped fn return value to an error; nil stays nil.
+func asError(val any) error {
+ if val == nil {
+ return nil
+ }
+ return val.(error)
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go
new file mode 100644
index 00000000000..cdefa711603
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go
@@ -0,0 +1,336 @@
+// File generated by specialize. Do not edit.
+
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated from sdf_invokers_arity.tmpl. DO NOT EDIT.
+
+package exec
+
+import (
+ "fmt"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+)
+
+func (n *cirInvoker) initCallFn() error {
+ // Expects a signature of the form:
+ // (context.Context?, key?, value) (restriction, error?)
+ // TODO(BEAM-9643): Link to full documentation.
+ switch fnT := n.fn.Fn.(type) {
+
+ case reflectx.Func1x1:
+ n.call = func() (rest any, err error) {
+ r0 := fnT.Call1x1(n.args[0])
+ return r0, nil
+ }
+
+ case reflectx.Func2x1:
+ n.call = func() (rest any, err error) {
+ r0 := fnT.Call2x1(n.args[0], n.args[1])
+ return r0, nil
+ }
+
+ case reflectx.Func3x1:
+ n.call = func() (rest any, err error) {
+ r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+ return r0, nil
+ }
+
+ case reflectx.Func1x2:
+ n.call = func() (rest any, err error) {
+ r0, r1 := fnT.Call1x2(n.args[0])
+ return r0, asError(r1)
+ }
+
+ case reflectx.Func2x2:
+ n.call = func() (rest any, err error) {
+ r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+ return r0, asError(r1)
+ }
+
+ case reflectx.Func3x2:
+ n.call = func() (rest any, err error) {
+ r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+ return r0, asError(r1)
+ }
+
+ default:
+ if len(n.fn.Param) < 1 || len(n.fn.Param) > 3 {
+ return errors.Errorf("CreateInitialRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+ }
+
+ n.call = func() (rest any, err error) {
+ ret := n.fn.Fn.Call(n.args)
+
+ switch len(ret) {
+ case 1:
+ return ret[0], nil
+ case 2:
+ return ret[0], asError(ret[1])
+ }
+
+ panic(fmt.Sprintf("CreateInitialRestriction has unexpected number of return values: %v", len(ret)))
+ }
+ }
+
+ return nil
+}
+
+func (n *srInvoker) initCallFn() error {
+ // Expects a signature of the form:
+ // (context.Context?, key?, value, restriction) ([]restriction, error?)
+ // TODO(BEAM-9643): Link to full documentation.
+ switch fnT := n.fn.Fn.(type) {
+
+ case reflectx.Func2x1:
+ n.call = func() (splits any, err error) {
+ r0 := fnT.Call2x1(n.args[0], n.args[1])
+ return r0, nil
+ }
+
+ case reflectx.Func3x1:
+ n.call = func() (splits any, err error) {
+ r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+ return r0, nil
+ }
+
+ case reflectx.Func4x1:
+ n.call = func() (splits any, err error) {
+ r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], n.args[3])
+ return r0, nil
+ }
+
+ case reflectx.Func2x2:
+ n.call = func() (splits any, err error) {
+ r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+ return r0, asError(r1)
+ }
+
+ case reflectx.Func3x2:
+ n.call = func() (splits any, err error) {
+ r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+ return r0, asError(r1)
+ }
+
+ case reflectx.Func4x2:
+ n.call = func() (splits any, err error) {
+ r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], n.args[3])
+ return r0, asError(r1)
+ }
+
+ default:
+ if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+ return errors.Errorf("SplitRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+ }
+
+ n.call = func() (splits any, err error) {
+ ret := n.fn.Fn.Call(n.args)
+
+ switch len(ret) {
+ case 1:
+ return ret[0], nil
+ case 2:
+ return ret[0], asError(ret[1])
+ }
+
+ panic(fmt.Sprintf("SplitRestriction has unexpected number of return values: %v", len(ret)))
+ }
+ }
+
+ return nil
+}
+
+func (n *rsInvoker) initCallFn() error {
+ // Expects a signature of the form:
+ // (context.Context?, key?, value, restriction) (float64, error?)
+ // TODO(BEAM-9643): Link to full documentation.
+ switch fnT := n.fn.Fn.(type) {
+
+ case reflectx.Func2x1:
+ n.call = func() (size float64, err error) {
+ r0 := fnT.Call2x1(n.args[0], n.args[1])
+ return r0.(float64), nil
+ }
+
+ case reflectx.Func3x1:
+ n.call = func() (size float64, err error) {
+ r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+ return r0.(float64), nil
+ }
+
+ case reflectx.Func4x1:
+ n.call = func() (size float64, err error) {
+ r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], n.args[3])
+ return r0.(float64), nil
+ }
+
+ case reflectx.Func2x2:
+ n.call = func() (size float64, err error) {
+ r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+ return r0.(float64), asError(r1)
+ }
+
+ case reflectx.Func3x2:
+ n.call = func() (size float64, err error) {
+ r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+ return r0.(float64), asError(r1)
+ }
+
+ case reflectx.Func4x2:
+ n.call = func() (size float64, err error) {
+ r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], n.args[3])
+ return r0.(float64), asError(r1)
+ }
+
+ default:
+ if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+ return errors.Errorf("RestrictionSize has unexpected number of parameters: %v", len(n.fn.Param))
+ }
+
+ n.call = func() (size float64, err error) {
+ ret := n.fn.Fn.Call(n.args)
+
+ switch len(ret) {
+ case 1:
+ return ret[0].(float64), nil
+ case 2:
+ return ret[0].(float64), asError(ret[1])
+ }
+
+ panic(fmt.Sprintf("RestrictionSize has unexpected number of return values: %v", len(ret)))
+ }
+ }
+
+ return nil
+}
+
+func (n *ctInvoker) initCallFn() error {
+ // Expects a signature of the form:
+ // (context.Context?, restriction) (sdf.RTracker, error?)
+ // TODO(BEAM-9643): Link to full documentation.
+ switch fnT := n.fn.Fn.(type) {
+
+ case reflectx.Func1x1:
+ n.call = func() (rt sdf.RTracker, err error) {
+ r0 := fnT.Call1x1(n.args[0])
+ return r0.(sdf.RTracker), nil
+ }
+
+ case reflectx.Func2x1:
+ n.call = func() (rt sdf.RTracker, err error) {
+ r0 := fnT.Call2x1(n.args[0], n.args[1])
+ return r0.(sdf.RTracker), nil
+ }
+
+ case reflectx.Func1x2:
+ n.call = func() (rt sdf.RTracker, err error) {
+ r0, r1 := fnT.Call1x2(n.args[0])
+ return r0.(sdf.RTracker), asError(r1)
+ }
+
+ case reflectx.Func2x2:
+ n.call = func() (rt sdf.RTracker, err error) {
+ r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+ return r0.(sdf.RTracker), asError(r1)
+ }
+
+ default:
+ if len(n.fn.Param) < 1 || len(n.fn.Param) > 2 {
+ return errors.Errorf("CreateTracker has unexpected number of parameters: %v", len(n.fn.Param))
+ }
+
+ n.call = func() (rt sdf.RTracker, err error) {
+ ret := n.fn.Fn.Call(n.args)
+
+ switch len(ret) {
+ case 1:
+ return ret[0].(sdf.RTracker), nil
+ case 2:
+ return ret[0].(sdf.RTracker), asError(ret[1])
+ }
+
+ panic(fmt.Sprintf("CreateTracker has unexpected number of return values: %v", len(ret)))
+ }
+ }
+
+ return nil
+}
+
+func (n *trInvoker) initCallFn() error {
+ // Expects a signature of the form:
+ // (context.Context?, sdf.RTracker, key?, value) (restriction, error?)
+ // TODO(BEAM-9643): Link to full documentation.
+ switch fnT := n.fn.Fn.(type) {
+
+ case reflectx.Func2x1:
+ n.call = func() (rest any, err error) {
+ r0 := fnT.Call2x1(n.args[0], n.args[1])
+ return r0, nil
+ }
+
+ case reflectx.Func3x1:
+ n.call = func() (rest any, err error) {
+ r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+ return r0, nil
+ }
+
+ case reflectx.Func4x1:
+ n.call = func() (rest any, err error) {
+ r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], n.args[3])
+ return r0, nil
+ }
+
+ case reflectx.Func2x2:
+ n.call = func() (rest any, err error) {
+ r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+ return r0, asError(r1)
+ }
+
+ case reflectx.Func3x2:
+ n.call = func() (rest any, err error) {
+ r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+ return r0, asError(r1)
+ }
+
+ case reflectx.Func4x2:
+ n.call = func() (rest any, err error) {
+ r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], n.args[3])
+ return r0, asError(r1)
+ }
+
+ default:
+ if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+ return errors.Errorf("TruncateRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+ }
+
+ n.call = func() (rest any, err error) {
+ ret := n.fn.Fn.Call(n.args)
+
+ switch len(ret) {
+ case 1:
+ return ret[0], nil
+ case 2:
+ return ret[0], asError(ret[1])
+ }
+
+ panic(fmt.Sprintf("TruncateRestriction has unexpected number of return values: %v", len(ret)))
+ }
+ }
+
+ return nil
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl
new file mode 100644
index 00000000000..7df994be0c7
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl
@@ -0,0 +1,246 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated from sdf_invokers_arity.tmpl. DO NOT EDIT.
+
+package exec
+
+import (
+ "fmt"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+)
+
+func (n *cirInvoker) initCallFn() error {
+ // Expects a signature of the form:
+ // (context.Context?, key?, value) (restriction, error?)
+ // TODO(BEAM-9643): Link to full documentation.
+ switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}{{/* One case per reflectx.Func<in>x<out> shape; upto N yields 0..N-1 and the gt guards keep $out in 1..2: (restriction, error?). */}}
+{{range $in := upto 4}}{{/* The gt guard keeps $in in 1..3: (context.Context?, key?, value). */}}
+ {{if gt $out 0}}
+ {{if gt $in 0}}
+ case reflectx.Func{{$in}}x{{$out}}:
+ n.call = func() (rest any, err error) {
+ {{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+ {{- if eq $out 1}}
+ return r0, nil
+ {{- else}}
+ return r0, asError(r1)
+ {{- end}}
+ }
+ {{end}}
+ {{end}}
+{{end}}
+{{end}}
+ default:{{/* fn is not one of the FuncNxM shapes above: validate the param count, then call via reflection. */}}
+ if len(n.fn.Param) < 1 || len(n.fn.Param) > 3 {
+ return errors.Errorf("CreateInitialRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+ }
+
+ n.call = func() (rest any, err error) {
+ ret := n.fn.Fn.Call(n.args)
+
+ switch len(ret) {
+ case 1:
+ return ret[0], nil
+ case 2:
+ return ret[0], asError(ret[1])
+ }
+
+ panic(fmt.Sprintf("CreateInitialRestriction has unexpected number of return values: %v", len(ret)))
+ }
+ }
+
+ return nil
+}
+
+func (n *srInvoker) initCallFn() error {
+ // Expects a signature of the form:
+ // (context.Context?, key?, value, restriction) ([]restriction, error?)
+ // TODO(BEAM-9643): Link to full documentation.
+ switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}{{/* One case per reflectx.Func<in>x<out> shape; upto N yields 0..N-1 and the gt guards keep $out in 1..2: ([]restriction, error?). */}}
+{{range $in := upto 5}}{{/* The gt guard keeps $in in 2..4: (context.Context?, key?, value, restriction). */}}
+ {{if gt $out 0}}
+ {{if gt $in 1}}
+ case reflectx.Func{{$in}}x{{$out}}:
+ n.call = func() (splits any, err error) {
+ {{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+ {{- if eq $out 1}}
+ return r0, nil
+ {{- else}}
+ return r0, asError(r1)
+ {{- end}}
+ }
+ {{end}}
+ {{end}}
+{{end}}
+{{end}}
+ default:{{/* fn is not one of the FuncNxM shapes above: validate the param count, then call via reflection. */}}
+ if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+ return errors.Errorf("SplitRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+ }
+
+ n.call = func() (splits any, err error) {
+ ret := n.fn.Fn.Call(n.args)
+
+ switch len(ret) {
+ case 1:
+ return ret[0], nil
+ case 2:
+ return ret[0], asError(ret[1])
+ }
+
+ panic(fmt.Sprintf("SplitRestriction has unexpected number of return values: %v", len(ret)))
+ }
+ }
+
+ return nil
+}
+
+func (n *rsInvoker) initCallFn() error {
+ // Expects a signature of the form:
+ // (context.Context?, key?, value, restriction) (float64, error?)
+ // TODO(BEAM-9643): Link to full documentation.
+ switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}{{/* One case per reflectx.Func<in>x<out> shape; upto N yields 0..N-1 and the gt guards keep $out in 1..2: (float64, error?). */}}
+{{range $in := upto 5}}{{/* The gt guard keeps $in in 2..4: (context.Context?, key?, value, restriction). */}}
+ {{if gt $out 0}}
+ {{if gt $in 1}}
+ case reflectx.Func{{$in}}x{{$out}}:
+ n.call = func() (size float64, err error) {
+ {{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+ {{- if eq $out 1}}
+ return r0.(float64), nil
+ {{- else}}
+ return r0.(float64), asError(r1)
+ {{- end}}
+ }
+ {{end}}
+ {{end}}
+{{end}}
+{{end}}
+ default:{{/* fn is not one of the FuncNxM shapes above: validate the param count, then call via reflection. */}}
+ if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+ return errors.Errorf("RestrictionSize has unexpected number of parameters: %v", len(n.fn.Param))
+ }
+
+ n.call = func() (size float64, err error) {
+ ret := n.fn.Fn.Call(n.args)
+
+ switch len(ret) {
+ case 1:
+ return ret[0].(float64), nil
+ case 2:
+ return ret[0].(float64), asError(ret[1])
+ }
+
+ panic(fmt.Sprintf("RestrictionSize has unexpected number of return values: %v", len(ret)))
+ }
+ }
+
+ return nil
+}
+
+func (n *ctInvoker) initCallFn() error {
+ // Expects a signature of the form:
+ // (context.Context?, restriction) (sdf.RTracker, error?)
+ // TODO(BEAM-9643): Link to full documentation.
+ switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}{{/* One case per reflectx.Func<in>x<out> shape; upto N yields 0..N-1 and the gt guards keep $out in 1..2: (sdf.RTracker, error?). */}}
+{{range $in := upto 3}}{{/* The gt guard keeps $in in 1..2: (context.Context?, restriction). */}}
+ {{if gt $out 0}}
+ {{if gt $in 0}}
+ case reflectx.Func{{$in}}x{{$out}}:
+ n.call = func() (rt sdf.RTracker, err error) {
+ {{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+ {{- if eq $out 1}}
+ return r0.(sdf.RTracker), nil
+ {{- else}}
+ return r0.(sdf.RTracker), asError(r1)
+ {{- end}}
+ }
+ {{end}}
+ {{end}}
+{{end}}
+{{end}}
+ default:{{/* fn is not one of the FuncNxM shapes above: validate the param count, then call via reflection. */}}
+ if len(n.fn.Param) < 1 || len(n.fn.Param) > 2 {
+ return errors.Errorf("CreateTracker has unexpected number of parameters: %v", len(n.fn.Param))
+ }
+
+ n.call = func() (rt sdf.RTracker, err error) {
+ ret := n.fn.Fn.Call(n.args)
+
+ switch len(ret) {
+ case 1:
+ return ret[0].(sdf.RTracker), nil
+ case 2:
+ return ret[0].(sdf.RTracker), asError(ret[1])
+ }
+
+ panic(fmt.Sprintf("CreateTracker has unexpected number of return values: %v", len(ret)))
+ }
+ }
+
+ return nil
+}
+
+func (n *trInvoker) initCallFn() error {
+ // Expects a signature of the form:
+ // (context.Context?, sdf.RTracker, key?, value) (restriction, error?)
+ // TODO(BEAM-9643): Link to full documentation.
+ switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}{{/* One case per reflectx.Func<in>x<out> shape; upto N yields 0..N-1 and the gt guards keep $out in 1..2: (restriction, error?). */}}
+{{range $in := upto 5}}{{/* The gt guard keeps $in in 2..4: (context.Context?, sdf.RTracker, key?, value). */}}
+ {{if gt $out 0}}
+ {{if gt $in 1}}
+ case reflectx.Func{{$in}}x{{$out}}:
+ n.call = func() (rest any, err error) {
+ {{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+ {{- if eq $out 1}}
+ return r0, nil
+ {{- else}}
+ return r0, asError(r1)
+ {{- end}}
+ }
+ {{end}}
+ {{end}}
+{{end}}
+{{end}}
+ default:{{/* fn is not one of the FuncNxM shapes above: validate the param count, then call via reflection. */}}
+ if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+ return errors.Errorf("TruncateRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+ }
+
+ n.call = func() (rest any, err error) {
+ ret := n.fn.Fn.Call(n.args)
+
+ switch len(ret) {
+ case 1:
+ return ret[0], nil
+ case 2:
+ return ret[0], asError(ret[1])
+ }
+
+ panic(fmt.Sprintf("TruncateRestriction has unexpected number of return values: %v", len(ret)))
+ }
+ }
+
+ return nil
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
index e308071aaf9..edef16e51d7 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
@@ -16,6 +16,8 @@
package exec
import (
+ "context"
+ "errors"
"testing"
"time"
@@ -48,13 +50,59 @@ func TestInvokes(t *testing.T) {
}
statefulWeFn := (*graph.SplittableDoFn)(dfn)
+ initialRestErrDfn, err := graph.NewDoFn(
+ &VetCreateInitialRestrictionErrSdf{},
+ graph.NumMainInputs(graph.MainSingle),
+ )
+ if err != nil {
+ t.Fatalf("invalid function: %v", err)
+ }
+ initialRestErrSdf := (*graph.SplittableDoFn)(initialRestErrDfn)
+
+ splitRestErrDfn, err := graph.NewDoFn(
+ &VetSplitRestrictionErrSdf{},
+ graph.NumMainInputs(graph.MainSingle),
+ )
+ if err != nil {
+ t.Fatalf("invalid function: %v", err)
+ }
+ splitRestErrSdf := (*graph.SplittableDoFn)(splitRestErrDfn)
+
+ restSizeErrDfn, err := graph.NewDoFn(
+ &VetRestrictionSizeErrSdf{},
+ graph.NumMainInputs(graph.MainSingle),
+ )
+ if err != nil {
+ t.Fatalf("invalid function: %v", err)
+ }
+ restSizeErrSdf := (*graph.SplittableDoFn)(restSizeErrDfn)
+
+ trackerErrDfn, err := graph.NewDoFn(
+ &VetCreateTrackerErrSdf{},
+ graph.NumMainInputs(graph.MainSingle),
+ )
+ if err != nil {
+ t.Fatalf("invalid function: %v", err)
+ }
+ trackerErrSdf := (*graph.SplittableDoFn)(trackerErrDfn)
+
+ truncateRestErrDfn, err := graph.NewDoFn(
+ &VetTruncateRestrictionErrSdf{},
+ graph.NumMainInputs(graph.MainSingle),
+ )
+ if err != nil {
+ t.Fatalf("invalid function: %v", err)
+ }
+ truncateRestErrSdf := (*graph.SplittableDoFn)(truncateRestErrDfn)
+
// Tests.
t.Run("CreateInitialRestriction Invoker (cirInvoker)", func(t *testing.T) {
tests := []struct {
- name string
- sdf *graph.SplittableDoFn
- elms *FullValue
- want *VetRestriction
+ name string
+ sdf *graph.SplittableDoFn
+ elms *FullValue
+ want *VetRestriction
+ wantErr bool
}{
{
name: "SingleElem",
@@ -68,6 +116,12 @@ func TestInvokes(t *testing.T) {
elms: &FullValue{Elm: 1, Elm2: 2},
want: &VetRestriction{ID: "KvSdf", CreateRest: true, Key: 1, Val: 2},
},
+ {
+ name: "Error",
+ sdf: initialRestErrSdf,
+ elms: &FullValue{Elm: 1},
+ wantErr: true,
+ },
}
for _, test := range tests {
test := test
@@ -77,7 +131,15 @@ func TestInvokes(t *testing.T) {
if err != nil {
t.Fatalf("newCreateInitialRestrictionInvoker failed: %v", err)
}
- got := invoker.Invoke(test.elms)
+
+ got, err := invoker.Invoke(context.Background(), test.elms)
+ if (err != nil) != test.wantErr {
+ t.Fatalf("Invoke(%v) error = %v, wantErr %v", test.elms, err, test.wantErr)
+ }
+ if test.wantErr {
+ return
+ }
+
if !cmp.Equal(got, test.want) {
t.Errorf("Invoke(%v) has incorrect output: got: %v, want: %v",
test.elms, got, test.want)
@@ -94,11 +156,12 @@ func TestInvokes(t *testing.T) {
t.Run("SplitRestriction Invoker (srInvoker)", func(t *testing.T) {
tests := []struct {
- name string
- sdf *graph.SplittableDoFn
- elms *FullValue
- rest *VetRestriction
- want []any
+ name string
+ sdf *graph.SplittableDoFn
+ elms *FullValue
+ rest *VetRestriction
+ want []any
+ wantErr bool
}{
{
name: "SingleElem",
@@ -119,6 +182,13 @@ func TestInvokes(t *testing.T) {
&VetRestriction{ID: "KvSdf.2", SplitRest: true, Key: 1, Val: 2},
},
},
+ {
+ name: "Error",
+ sdf: splitRestErrSdf,
+ elms: &FullValue{Elm: 1},
+ rest: &VetRestriction{ID: "Sdf"},
+ wantErr: true,
+ },
}
for _, test := range tests {
test := test
@@ -128,8 +198,16 @@ func TestInvokes(t *testing.T) {
if err != nil {
t.Fatalf("newSplitRestrictionInvoker failed: %v", err)
}
+
rest := *test.rest // Create a copy because our test SDF edits the restriction.
- got := invoker.Invoke(test.elms, &rest)
+ got, err := invoker.Invoke(context.Background(), test.elms, &rest)
+ if (err != nil) != test.wantErr {
+ t.Fatalf("Invoke(%v, %v) error = %v, wantErr %v", test.elms, test.rest, err, test.wantErr)
+ }
+ if test.wantErr {
+ return
+ }
+
if !cmp.Equal(got, test.want) {
t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v",
test.elms, test.rest, got, test.want)
@@ -152,6 +230,7 @@ func TestInvokes(t *testing.T) {
rest *VetRestriction
want float64
restWant *VetRestriction
+ wantErr bool
}{
{
name: "SingleElem",
@@ -168,6 +247,13 @@ func TestInvokes(t *testing.T) {
want: 3,
restWant: &VetRestriction{ID: "KvSdf", RestSize: true, Key: 1, Val: 2},
},
+ {
+ name: "Error",
+ sdf: restSizeErrSdf,
+ elms: &FullValue{Elm: 1},
+ rest: &VetRestriction{ID: "Sdf"},
+ wantErr: true,
+ },
}
for _, test := range tests {
test := test
@@ -178,7 +264,15 @@ func TestInvokes(t *testing.T) {
t.Fatalf("newRestrictionSizeInvoker failed: %v", err)
}
rest := *test.rest // Create a copy because our test SDF edits the restriction.
- got := invoker.Invoke(test.elms, &rest)
+
+ got, err := invoker.Invoke(context.Background(), test.elms, &rest)
+ if (err != nil) != test.wantErr {
+ t.Fatalf("Invoke(%v, %v) error = %v, wantErr %v", test.elms, test.rest, err, test.wantErr)
+ }
+ if test.wantErr {
+ return
+ }
+
if !cmp.Equal(got, test.want) {
t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v",
test.elms, test.rest, got, test.want)
@@ -199,10 +293,11 @@ func TestInvokes(t *testing.T) {
t.Run("CreateTracker Invoker (ctInvoker)", func(t *testing.T) {
tests := []struct {
- name string
- sdf *graph.SplittableDoFn
- rest *VetRestriction
- want *VetRTracker
+ name string
+ sdf *graph.SplittableDoFn
+ rest *VetRestriction
+ want *VetRTracker
+ wantErr bool
}{
{
name: "SingleElem",
@@ -215,6 +310,12 @@ func TestInvokes(t *testing.T) {
rest: &VetRestriction{ID: "KvSdf"},
want: &VetRTracker{&VetRestriction{ID: "KvSdf", CreateTracker: true}},
},
+ {
+ name: "Error",
+ sdf: trackerErrSdf,
+ rest: &VetRestriction{ID: "Sdf"},
+ wantErr: true,
+ },
}
for _, test := range tests {
test := test
@@ -224,7 +325,15 @@ func TestInvokes(t *testing.T) {
if err != nil {
t.Fatalf("newCreateTrackerInvoker failed: %v", err)
}
- got := invoker.Invoke(test.rest)
+
+ got, err := invoker.Invoke(context.Background(), test.rest)
+ if (err != nil) != test.wantErr {
+ t.Fatalf("Invoke(%v) error = %v, wantErr %v", test.rest, err, test.wantErr)
+ }
+ if test.wantErr {
+ return
+ }
+
if !cmp.Equal(got, test.want) {
t.Errorf("Invoke(%v) has incorrect output: got: %v, want: %v",
test.rest, got, test.want)
@@ -310,11 +419,12 @@ func TestInvokes(t *testing.T) {
t.Run("TruncateRestriction Invoker (trInvoker)", func(t *testing.T) {
tests := []struct {
- name string
- sdf *graph.SplittableDoFn
- elms *FullValue
- rest *VetRestriction
- want any
+ name string
+ sdf *graph.SplittableDoFn
+ elms *FullValue
+ rest *VetRestriction
+ want any
+ wantErr bool
}{
{
name: "SingleElem",
@@ -329,6 +439,13 @@ func TestInvokes(t *testing.T) {
rest: &VetRestriction{ID: "KvSdf"},
want: &VetRestriction{ID: "KvSdf", CreateTracker: true, TruncateRest: true, RestSize: true, Key: 1, Val: 2},
},
+ {
+ name: "Error",
+ sdf: truncateRestErrSdf,
+ elms: &FullValue{Elm: 1},
+ rest: &VetRestriction{ID: "Sdf"},
+ wantErr: true,
+ },
}
for _, test := range tests {
test := test
@@ -336,24 +453,38 @@ func TestInvokes(t *testing.T) {
ctFn := test.sdf.CreateTrackerFn()
rsFn := test.sdf.RestrictionSizeFn()
t.Run(test.name, func(t *testing.T) {
+ ctx := context.Background()
rest := test.rest // Create a copy because our test SDF edits the restriction.
ctInvoker, err := newCreateTrackerInvoker(ctFn)
if err != nil {
t.Fatalf("newCreateTrackerInvoker failed: %v", err)
}
- rt := ctInvoker.Invoke(rest)
+ rt, err := ctInvoker.Invoke(ctx, rest)
+ if err != nil {
+ t.Fatalf("ctInvoker.Invoke(%v) failed: %v", rest, err)
+ }
trInvoker, err := newTruncateRestrictionInvoker(fn)
if err != nil {
t.Fatalf("newTruncateRestrictionInvoker failed: %v", err)
}
- trRest := trInvoker.Invoke(rt, test.elms)
+
+ trRest, err := trInvoker.Invoke(ctx, rt, test.elms)
+ if (err != nil) != test.wantErr {
+ t.Fatalf("trInvoker.Invoke(%v, %v) = %v, wantErr %v", rt, test.elms, err, test.wantErr)
+ }
+ if test.wantErr {
+ return
+ }
rsInvoker, err := newRestrictionSizeInvoker(rsFn)
if err != nil {
t.Fatalf("newRestrictionSizeInvoker failed: %v", err)
}
- _ = rsInvoker.Invoke(test.elms, trRest)
+ if _, err := rsInvoker.Invoke(ctx, test.elms, trRest); err != nil {
+ t.Fatalf("rsInvoker.Invoke(%v, %v) failed: %v", test.elms, trRest, err)
+ }
+
if !cmp.Equal(trRest, test.want) {
t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v",
test.elms, test.rest, trRest, test.want)
@@ -411,24 +542,35 @@ func TestInvokes(t *testing.T) {
ctFn := test.sdf.CreateTrackerFn()
rsFn := test.sdf.RestrictionSizeFn()
t.Run(test.name, func(t *testing.T) {
+ ctx := context.Background()
rest := test.rest // Create a copy because our test SDF edits the restriction.
ctInvoker, err := newCreateTrackerInvoker(ctFn)
if err != nil {
t.Fatalf("newCreateTrackerInvoker failed: %v", err)
}
- rt := ctInvoker.Invoke(rest)
+ rt, err := ctInvoker.Invoke(ctx, rest)
+ if err != nil {
+ t.Fatalf("ctInvoker.Invoke(%v) failed: %v", rest, err)
+ }
trInvoker, err := newDefaultTruncateRestrictionInvoker()
if err != nil {
t.Fatalf("newTruncateRestrictionInvoker failed: %v", err)
}
- trRest := trInvoker.Invoke(rt, test.elms)
+ trRest, err := trInvoker.Invoke(ctx, rt, test.elms)
+ if err != nil {
+ t.Fatalf("trInvoker.Invoke(%v, %v) failed: %v", rt, test.elms, err)
+ }
+
if trRest != nil {
rsInvoker, err := newRestrictionSizeInvoker(rsFn)
if err != nil {
t.Fatalf("newRestrictionSizeInvoker failed: %v", err)
}
- _ = rsInvoker.Invoke(test.elms, trRest)
+ if _, err := rsInvoker.Invoke(ctx, test.elms, trRest); err != nil {
+ t.Fatalf("rsInvoker.Invoke(%v, %v) failed: %v", test.elms, trRest, err)
+ }
+
if !cmp.Equal(trRest, test.want) {
t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v",
test.elms, test.rest, trRest, test.want)
@@ -732,3 +874,55 @@ func (fn *VetEmptyInitialSplitSdf) ProcessElement(rt *VetRTracker, i int, emit f
emit(rest)
return sdf.ResumeProcessingIn(1 * time.Second)
}
+
+var errSdf = errors.New("SDF error") // sentinel error returned by every Vet*ErrSdf method in this group
+
+// VetCreateInitialRestrictionErrSdf embeds VetSdf but overrides CreateInitialRestriction
+// to always return (nil, errSdf); used by the cirInvoker "Error" case in TestInvokes.
+type VetCreateInitialRestrictionErrSdf struct {
+ VetSdf
+}
+
+func (fn *VetCreateInitialRestrictionErrSdf) CreateInitialRestriction(i int) (*VetRestriction, error) {
+ return nil, errSdf
+}
+
+// VetSplitRestrictionErrSdf embeds VetSdf but overrides SplitRestriction
+// to always return (nil, errSdf); used by the srInvoker "Error" case in TestInvokes.
+type VetSplitRestrictionErrSdf struct {
+ VetSdf
+}
+
+func (fn *VetSplitRestrictionErrSdf) SplitRestriction(int, *VetRestriction) ([]*VetRestriction, error) {
+ return nil, errSdf
+}
+
+// VetRestrictionSizeErrSdf embeds VetSdf but overrides RestrictionSize
+// to always return (-1, errSdf); used by the RestrictionSize "Error" case in TestInvokes.
+type VetRestrictionSizeErrSdf struct {
+ VetSdf
+}
+
+func (fn *VetRestrictionSizeErrSdf) RestrictionSize(int, *VetRestriction) (float64, error) {
+ return -1, errSdf // -1 is only a placeholder size accompanying the error
+}
+
+// VetCreateTrackerErrSdf embeds VetSdf but overrides CreateTracker
+// to always return (nil, errSdf); used by the ctInvoker "Error" case in TestInvokes.
+type VetCreateTrackerErrSdf struct {
+ VetSdf
+}
+
+func (fn *VetCreateTrackerErrSdf) CreateTracker(*VetRestriction) (*VetRTracker, error) {
+ return nil, errSdf
+}
+
+// VetTruncateRestrictionErrSdf embeds VetSdf but overrides TruncateRestriction
+// to always return (nil, errSdf); used by the trInvoker "Error" case in TestInvokes.
+type VetTruncateRestrictionErrSdf struct {
+ VetSdf
+}
+
+func (fn *VetTruncateRestrictionErrSdf) TruncateRestriction(*VetRTracker, int) (*VetRestriction, error) {
+ return nil, errSdf
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
index 414f28553a8..a0380796e86 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
@@ -1114,10 +1114,11 @@ func TestAsSplittableUnit(t *testing.T) {
// Call from SplittableUnit and check results.
su := SplittableUnit(node)
- if err := node.Up(context.Background()); err != nil {
+ ctx := context.Background()
+ if err := node.Up(ctx); err != nil {
t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err)
}
- gotPrimaries, gotResiduals, err := su.Split(test.frac)
+ gotPrimaries, gotResiduals, err := su.Split(ctx, test.frac)
if err != nil {
t.Fatalf("SplittableUnit.Split(%v) failed: %v", test.frac, err)
}
@@ -1184,10 +1185,11 @@ func TestAsSplittableUnit(t *testing.T) {
// Call from SplittableUnit and check results.
su := SplittableUnit(node)
- if err := node.Up(context.Background()); err != nil {
+ ctx := context.Background()
+ if err := node.Up(ctx); err != nil {
t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err)
}
- _, _, err := su.Split(0.5)
+ _, _, err := su.Split(ctx, 0.5)
if err == nil {
t.Errorf("SplittableUnit.Split(%v) was expected to fail.", test.in)
}
@@ -1251,10 +1253,11 @@ func TestAsSplittableUnit(t *testing.T) {
node.currW = 0
// Call from SplittableUnit and check results.
su := SplittableUnit(node)
- if err := node.Up(context.Background()); err != nil {
+ ctx := context.Background()
+ if err := node.Up(ctx); err != nil {
t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err)
}
- gotResiduals, err := su.Checkpoint()
+ gotResiduals, err := su.Checkpoint(ctx)
if err != nil {
t.Fatalf("SplittableUnit.Checkpoint() returned error, got %v", err)
@@ -1401,7 +1404,7 @@ func TestMultiWindowProcessing(t *testing.T) {
// Split should hit window boundary between 2 and 3. We don't need to check
// the split result here, just the effects it has on currW and numW.
frac := 0.5
- if _, _, err := su.Split(frac); err != nil {
+ if _, _, err := su.Split(context.Background(), frac); err != nil {
t.Errorf("Split(%v) failed with error: %v", frac, err)
}
if got, want := node.currW, blockW; got != want {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/config/config.go b/sdks/go/pkg/beam/runners/prism/internal/config/config.go
index fc2b68d092f..9c3bdd012bc 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/config/config.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/config/config.go
@@ -244,4 +244,4 @@ func (r *HandlerRegistry) GetVariant(name string) *Variant {
return nil
}
return &Variant{parent: r, name: name, handlers: vs.Handlers}
-}
\ No newline at end of file
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go b/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
index 3c7cae97397..7b553f6ad65 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
@@ -33,4 +33,4 @@ func Test_toUrn(t *testing.T) {
if got := quickUrn(pipepb.StandardPTransforms_PAR_DO); got != want {
t.Errorf("quickUrn(\"pipepb.StandardPTransforms_PAR_DO\") = %v, want %v", got, want)
}
-}
\ No newline at end of file
+}