You are viewing a plain text version of this content. The canonical link to it was not preserved in this plain-text copy.
Posted to commits@beam.apache.org by lo...@apache.org on 2023/02/16 20:02:22 UTC

[beam] branch master updated: [Go SDK]: Allow SDF methods to have context param and error return value (#25437)

This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 29ea6e0eb8a [Go SDK]: Allow SDF methods to have context param and error return value (#25437)
29ea6e0eb8a is described below

commit 29ea6e0eb8a2f2cf571d1a27799d125ca051008b
Author: Johanna Öjeling <51...@users.noreply.github.com>
AuthorDate: Thu Feb 16 21:02:08 2023 +0100

    [Go SDK]: Allow SDF methods to have context param and error return value (#25437)
    
    * Allow context param and error return value in SDF validation
    
    * Use context param and error return value in SDF method invocation
    
    * Run go fmt
    
    * Clean up error messages from "fn reflect.methodValueCall"
    
    * Validate the return value count more accurately
---
 sdks/go/pkg/beam/core/graph/fn.go                  | 108 ++++---
 sdks/go/pkg/beam/core/graph/fn_test.go             |  87 ++++++
 sdks/go/pkg/beam/core/runtime/exec/datasource.go   |   8 +-
 .../pkg/beam/core/runtime/exec/datasource_test.go  |  12 +-
 sdks/go/pkg/beam/core/runtime/exec/sdf.go          | 134 +++++---
 sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go | 339 +++++++++------------
 .../beam/core/runtime/exec/sdf_invokers_arity.go   | 336 ++++++++++++++++++++
 .../beam/core/runtime/exec/sdf_invokers_arity.tmpl | 246 +++++++++++++++
 .../beam/core/runtime/exec/sdf_invokers_test.go    | 250 +++++++++++++--
 sdks/go/pkg/beam/core/runtime/exec/sdf_test.go     |  17 +-
 .../beam/runners/prism/internal/config/config.go   |   2 +-
 .../beam/runners/prism/internal/urns/urns_test.go  |   2 +-
 12 files changed, 1227 insertions(+), 314 deletions(-)

diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go
index 907af4045b5..25b846370fb 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -866,14 +866,15 @@ func validateSdfSignatures(fn *Fn, numMainIn mainInputs) error {
 	// CreateInitialRestriction.
 	if numMainIn == MainUnknown {
 		initialRestFn := fn.methods[createInitialRestrictionName]
-		paramNum := len(initialRestFn.Param)
+		paramNum := len(initialRestFn.Params(funcx.FnValue))
+
 		switch paramNum {
 		case int(MainSingle), int(MainKv):
 			num = paramNum
 		default: // Can't infer because method has invalid # of main inputs.
-			err := errors.Errorf("invalid number of params in method %v. got: %v, want: %v or %v",
+			err := errors.Errorf("invalid number of main input params in method %v. got: %v, want: %v or %v",
 				createInitialRestrictionName, paramNum, int(MainSingle), int(MainKv))
-			return errors.SetTopLevelMsgf(err, "Invalid number of parameters in method %v. "+
+			return errors.SetTopLevelMsgf(err, "Invalid number of main input parameters in method %v. "+
 				"Got: %v, Want: %v or %v. Check that the signature conforms to the expected signature for %v, "+
 				"and that elements in SDF method parameters match elements in %v.",
 				createInitialRestrictionName, paramNum, int(MainSingle), int(MainKv), createInitialRestrictionName, processElementName)
@@ -894,7 +895,7 @@ func validateSdfSignatures(fn *Fn, numMainIn mainInputs) error {
 // in each SDF method in the given Fn, and returns an error if a method has an
 // invalid/unexpected number.
 func validateSdfSigNumbers(fn *Fn, num int) error {
-	paramNums := map[string]int{
+	reqParamNums := map[string]int{
 		createInitialRestrictionName: num,
 		splitRestrictionName:         num + 1,
 		restrictionSizeName:          num + 1,
@@ -904,32 +905,52 @@ func validateSdfSigNumbers(fn *Fn, num int) error {
 	optionalSdfs := map[string]bool{
 		truncateRestrictionName: true,
 	}
-	returnNum := 1 // TODO(BEAM-3301): Enable optional error params in SDF methods.
+	reqReturnNum := 1
 
 	for _, name := range sdfNames {
 		method, ok := fn.methods[name]
 		if !ok && optionalSdfs[name] {
 			continue
 		}
-		if len(method.Param) != paramNums[name] {
-			err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v",
-				name, len(method.Param), paramNums[name])
+
+		reqParamNum := reqParamNums[name]
+		if !sdfHasValidParamNum(method.Param, reqParamNum) {
+			err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v or optionally %v "+
+				"if first param is of type context.Context", name, len(method.Param), reqParamNum, reqParamNum+1)
 			return errors.SetTopLevelMsgf(err, "Unexpected number of parameters in method %v. "+
-				"Got: %v, Want: %v. Check that the signature conforms to the expected signature for %v, "+
-				"and that elements in SDF method parameters match elements in %v.",
-				name, len(method.Param), paramNums[name], name, processElementName)
+				"Got: %v, Want: %v or optionally %v if first param is of type context.Context. "+
+				"Check that the signature conforms to the expected signature for %v, and that elements in SDF method "+
+				"parameters match elements in %v.", name, len(method.Param), reqParamNum, reqParamNum+1,
+				name, processElementName)
 		}
-		if len(method.Ret) != returnNum {
-			err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v",
-				name, len(method.Ret), returnNum)
+		if !sdfHasValidReturnNum(method.Ret, reqReturnNum) {
+			err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v or optionally %v "+
+				"if last value is of type error", name, len(method.Ret), reqReturnNum, reqReturnNum+1)
 			return errors.SetTopLevelMsgf(err, "Unexpected number of return values in method %v. "+
-				"Got: %v, Want: %v. Check that the signature conforms to the expected signature for %v.",
-				name, len(method.Ret), returnNum, name)
+				"Got: %v, Want: %v or optionally %v if last value is of type error. "+
+				"Check that the signature conforms to the expected signature for %v.",
+				name, len(method.Ret), reqReturnNum, reqReturnNum+1, name)
 		}
 	}
 	return nil
 }
 
+func sdfHasValidParamNum(params []funcx.FnParam, requiredNum int) bool {
+	if len(params) == requiredNum {
+		return true
+	}
+
+	return len(params) == requiredNum+1 && params[0].Kind == funcx.FnContext
+}
+
+func sdfHasValidReturnNum(returns []funcx.ReturnParam, requiredNum int) bool {
+	if len(returns) == requiredNum {
+		return true
+	}
+
+	return len(returns) == requiredNum+1 && returns[len(returns)-1].Kind == funcx.RetError
+}
+
 // validateSdfSigTypes validates the types of the parameters and return values
 // in each SDF method in the given Fn, and returns an error if a method has an
 // invalid/mismatched type. Assumes that the number of parameters and return
@@ -940,22 +961,25 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 
 	for _, name := range requiredSdfNames {
 		method := fn.methods[name]
+		startIdx := sdfRequiredParamStartIndex(method)
+
 		switch name {
 		case createInitialRestrictionName:
-			if err := validateSdfElementT(fn, createInitialRestrictionName, method, num, 0); err != nil {
+			if err := validateSdfElementT(fn, createInitialRestrictionName, method, num, startIdx); err != nil {
 				return err
 			}
 		case splitRestrictionName:
-			if err := validateSdfElementT(fn, splitRestrictionName, method, num, 0); err != nil {
+			if err := validateSdfElementT(fn, splitRestrictionName, method, num, startIdx); err != nil {
 				return err
 			}
-			if method.Param[num].T != restrictionT {
+			idx := num + startIdx
+			if method.Param[idx].T != restrictionT {
 				err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
-					splitRestrictionName, num, method.Param[num].T, restrictionT)
+					splitRestrictionName, idx, method.Param[idx].T, restrictionT)
 				return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+
 					"parameter at index %v. Got: %v, Want: %v (from method %v). "+
 					"Ensure that all restrictions in an SDF are the same type.",
-					splitRestrictionName, num, method.Param[num].T, restrictionT, createInitialRestrictionName)
+					splitRestrictionName, idx, method.Param[idx].T, restrictionT, createInitialRestrictionName)
 			}
 			if method.Ret[0].T.Kind() != reflect.Slice ||
 				method.Ret[0].T.Elem() != restrictionT {
@@ -967,16 +991,17 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 					splitRestrictionName, 0, method.Ret[0].T, reflect.SliceOf(restrictionT), createInitialRestrictionName, splitRestrictionName)
 			}
 		case restrictionSizeName:
-			if err := validateSdfElementT(fn, restrictionSizeName, method, num, 0); err != nil {
+			if err := validateSdfElementT(fn, restrictionSizeName, method, num, startIdx); err != nil {
 				return err
 			}
-			if method.Param[num].T != restrictionT {
+			idx := num + startIdx
+			if method.Param[idx].T != restrictionT {
 				err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
-					restrictionSizeName, num, method.Param[num].T, restrictionT)
+					restrictionSizeName, idx, method.Param[idx].T, restrictionT)
 				return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+
 					"parameter at index %v. Got: %v, Want: %v (from method %v). "+
 					"Ensure that all restrictions in an SDF are the same type.",
-					restrictionSizeName, num, method.Param[num].T, restrictionT, createInitialRestrictionName)
+					restrictionSizeName, idx, method.Param[idx].T, restrictionT, createInitialRestrictionName)
 			}
 			if method.Ret[0].T != reflectx.Float64 {
 				err := errors.Errorf("invalid output type in method %v, return %v. got: %v, want: %v",
@@ -986,13 +1011,13 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 					restrictionSizeName, 0, method.Ret[0].T, reflectx.Float64)
 			}
 		case createTrackerName:
-			if method.Param[0].T != restrictionT {
+			if method.Param[startIdx].T != restrictionT {
 				err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
-					createTrackerName, 0, method.Param[0].T, restrictionT)
+					createTrackerName, startIdx, method.Param[startIdx].T, restrictionT)
 				return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+
 					"parameter at index %v. Got: %v, Want: %v (from method %v). "+
 					"Ensure that all restrictions in an SDF are the same type.",
-					createTrackerName, 0, method.Param[0].T, restrictionT, createInitialRestrictionName)
+					createTrackerName, startIdx, method.Param[startIdx].T, restrictionT, createInitialRestrictionName)
 			}
 			if !method.Ret[0].T.Implements(rTrackerT) {
 				err := errors.Errorf("invalid output type in method %v, return %v: %v does not implement sdf.RTracker",
@@ -1020,15 +1045,18 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 		if !ok {
 			continue
 		}
+
+		startIdx := sdfRequiredParamStartIndex(method)
+
 		switch name {
 		case truncateRestrictionName:
-			if method.Param[0].T != rTrackerImplT {
+			if method.Param[startIdx].T != rTrackerImplT {
 				err := errors.Errorf("mismatched restriction tracker type in method %v, param %v. got: %v, want: %v",
-					truncateRestrictionName, 0, method.Param[0].T, rTrackerImplT)
+					truncateRestrictionName, startIdx, method.Param[startIdx].T, rTrackerImplT)
 				return errors.SetTopLevelMsgf(err, "Mismatched restriction tracker type in method %v, "+
 					"parameter at index %v. Got: %v, Want: %v (from method %v). "+
 					"Ensure that restriction tracker is the first parameter.",
-					truncateRestrictionName, 0, method.Param[0].T, rTrackerImplT, createTrackerName)
+					truncateRestrictionName, startIdx, method.Param[startIdx].T, rTrackerImplT, createTrackerName)
 			}
 			if method.Ret[0].T != restrictionT {
 				err := errors.Errorf("invalid output type in method %v, return %v. got: %v, want: %v",
@@ -1052,6 +1080,14 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 	return nil
 }
 
+func sdfRequiredParamStartIndex(method *funcx.Fn) int {
+	if ctxIndex, ok := method.Context(); ok {
+		return ctxIndex + 1
+	}
+
+	return 0
+}
+
 // validateSdfElementT validates that element types in an SDF method are
 // consistent with the ProcessElement method. This method assumes that the
 // first 'num' parameters starting with startIndex are the elements.
@@ -1062,13 +1098,14 @@ func validateSdfElementT(fn *Fn, name string, method *funcx.Fn, num int, startIn
 	pos, _, _ := processFn.Inputs()
 
 	for i := 0; i < num; i++ {
-		if method.Param[i+startIndex].T != processFn.Param[pos+i].T {
+		idx := i + startIndex
+		if method.Param[idx].T != processFn.Param[pos+i].T {
 			err := errors.Errorf("mismatched element type in method %v, param %v. got: %v, want: %v",
-				name, i, method.Param[i].T, processFn.Param[pos+i].T)
+				name, idx, method.Param[idx].T, processFn.Param[pos+i].T)
 			return errors.SetTopLevelMsgf(err, "Mismatched element type in method %v, "+
 				"parameter at index %v. Got: %v, Want: %v (from method %v). "+
 				"Ensure that element parameters in SDF methods have consistent types with element parameters in %v.",
-				name, i, method.Param[i].T, processFn.Param[pos+i].T, processElementName, processElementName)
+				name, idx, method.Param[idx].T, processFn.Param[pos+i].T, processElementName, processElementName)
 		}
 	}
 	return nil
@@ -1178,7 +1215,8 @@ func validateStatefulWatermarkSig(fn *Fn, numMainIn int) error {
 	// CreateInitialRestriction.
 	if numMainIn == int(MainUnknown) {
 		initialRestFn := fn.methods[createInitialRestrictionName]
-		paramNum := len(initialRestFn.Param)
+		paramNum := len(initialRestFn.Params(funcx.FnValue))
+
 		switch paramNum {
 		case int(MainSingle), int(MainKv):
 			numMainIn = paramNum
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go
index d2f88a8a5ce..cf44761d4f3 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -190,6 +190,9 @@ func TestNewDoFnSdf(t *testing.T) {
 		}{
 			{dfn: &GoodSdf{}, main: MainSingle},
 			{dfn: &GoodSdfKv{}, main: MainKv},
+			{dfn: &GoodSdfWContext{}, main: MainSingle},
+			{dfn: &GoodSdfKvWContext{}, main: MainKv},
+			{dfn: &GoodSdfWErr{}, main: MainSingle},
 			{dfn: &GoodIgnoreOtherExportedMethods{}, main: MainSingle},
 		}
 
@@ -987,6 +990,90 @@ func (fn *GoodSdfKv) TruncateRestriction(*RTrackerT, int, int) RestT {
 	return RestT{}
 }
 
+type GoodSdfWContext struct {
+	*GoodDoFn
+}
+
+func (fn *GoodSdfWContext) CreateInitialRestriction(context.Context, int) RestT {
+	return RestT{}
+}
+
+func (fn *GoodSdfWContext) SplitRestriction(context.Context, int, RestT) []RestT {
+	return []RestT{}
+}
+
+func (fn *GoodSdfWContext) RestrictionSize(context.Context, int, RestT) float64 {
+	return 0
+}
+
+func (fn *GoodSdfWContext) CreateTracker(context.Context, RestT) *RTrackerT {
+	return &RTrackerT{}
+}
+
+func (fn *GoodSdfWContext) ProcessElement(context.Context, *RTrackerT, int) (int, sdf.ProcessContinuation) {
+	return 0, sdf.StopProcessing()
+}
+
+func (fn *GoodSdfWContext) TruncateRestriction(context.Context, *RTrackerT, int) RestT {
+	return RestT{}
+}
+
+type GoodSdfKvWContext struct {
+	*GoodDoFnKv
+}
+
+func (fn *GoodSdfKvWContext) CreateInitialRestriction(context.Context, int, int) RestT {
+	return RestT{}
+}
+
+func (fn *GoodSdfKvWContext) SplitRestriction(context.Context, int, int, RestT) []RestT {
+	return []RestT{}
+}
+
+func (fn *GoodSdfKvWContext) RestrictionSize(context.Context, int, int, RestT) float64 {
+	return 0
+}
+
+func (fn *GoodSdfKvWContext) CreateTracker(context.Context, RestT) *RTrackerT {
+	return &RTrackerT{}
+}
+
+func (fn *GoodSdfKvWContext) ProcessElement(context.Context, *RTrackerT, int, int) (int, sdf.ProcessContinuation) {
+	return 0, sdf.StopProcessing()
+}
+
+func (fn *GoodSdfKvWContext) TruncateRestriction(context.Context, *RTrackerT, int, int) RestT {
+	return RestT{}
+}
+
+type GoodSdfWErr struct {
+	*GoodDoFn
+}
+
+func (fn *GoodSdfWErr) CreateInitialRestriction(int) (RestT, error) {
+	return RestT{}, nil
+}
+
+func (fn *GoodSdfWErr) SplitRestriction(int, RestT) ([]RestT, error) {
+	return []RestT{}, nil
+}
+
+func (fn *GoodSdfWErr) RestrictionSize(int, RestT) (float64, error) {
+	return 0, nil
+}
+
+func (fn *GoodSdfWErr) CreateTracker(RestT) (*RTrackerT, error) {
+	return &RTrackerT{}, nil
+}
+
+func (fn *GoodSdfWErr) ProcessElement(*RTrackerT, int) (int, sdf.ProcessContinuation, error) {
+	return 0, sdf.StopProcessing(), nil
+}
+
+func (fn *GoodSdfWErr) TruncateRestriction(*RTrackerT, int) (RestT, error) {
+	return RestT{}, nil
+}
+
 type GoodIgnoreOtherExportedMethods struct {
 	*GoodSdf
 }
diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go
index 9c4de0564c8..a6347fc8d0e 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go
@@ -196,7 +196,7 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) {
 		// Check if there's a continuation and return residuals
 		// Needs to be done immeadiately after processing to not lose the element.
 		if c := n.getProcessContinuation(); c != nil {
-			cp, err := n.checkpointThis(c)
+			cp, err := n.checkpointThis(ctx, c)
 			if err != nil {
 				// Errors during checkpointing should fail a bundle.
 				return nil, err
@@ -422,7 +422,7 @@ type Checkpoint struct {
 // splittable or has not returned a resuming continuation, the function returns an empty
 // SplitResult, a negative resumption time, and a false boolean to indicate that no split
 // occurred.
-func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, error) {
+func (n *DataSource) checkpointThis(ctx context.Context, pc sdf.ProcessContinuation) (*Checkpoint, error) {
 	n.mu.Lock()
 	defer n.mu.Unlock()
 
@@ -435,7 +435,7 @@ func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, er
 	ow := su.GetOutputWatermark()
 
 	// Checkpointing is functionally a split at fraction 0.0
-	rs, err := su.Checkpoint()
+	rs, err := su.Checkpoint(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -530,7 +530,7 @@ func (n *DataSource) Split(ctx context.Context, splits []int64, frac float64, bu
 	// Get the output watermark before splitting to avoid accidentally overestimating
 	ow := su.GetOutputWatermark()
 	// Otherwise, perform a sub-element split.
-	ps, rs, err := su.Split(fr)
+	ps, rs, err := su.Split(ctx, fr)
 	if err != nil {
 		return SplitResult{}, err
 	}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
index 64a37739b24..2da3284f016 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
@@ -631,7 +631,7 @@ type TestSplittableUnit struct {
 
 // Split checks the input fraction for correctness, but otherwise always returns
 // a successful split. The split elements are just copies of the original.
-func (n *TestSplittableUnit) Split(f float64) ([]*FullValue, []*FullValue, error) {
+func (n *TestSplittableUnit) Split(_ context.Context, f float64) ([]*FullValue, []*FullValue, error) {
 	if f > 1.0 || f < 0.0 {
 		return nil, nil, errors.Errorf("Error")
 	}
@@ -639,8 +639,8 @@ func (n *TestSplittableUnit) Split(f float64) ([]*FullValue, []*FullValue, error
 }
 
 // Checkpoint routes through the Split() function to satisfy the interface.
-func (n *TestSplittableUnit) Checkpoint() ([]*FullValue, error) {
-	_, r, err := n.Split(0.0)
+func (n *TestSplittableUnit) Checkpoint(ctx context.Context) ([]*FullValue, error) {
+	_, r, err := n.Split(ctx, 0.0)
 	return r, err
 }
 
@@ -876,13 +876,13 @@ func TestSplitHelper(t *testing.T) {
 
 func TestCheckpointing(t *testing.T) {
 	t.Run("nil", func(t *testing.T) {
-		cps, err := (&DataSource{}).checkpointThis(nil)
+		cps, err := (&DataSource{}).checkpointThis(context.Background(), nil)
 		if err != nil {
 			t.Fatalf("checkpointThis() = %v, %v", cps, err)
 		}
 	})
 	t.Run("Stop", func(t *testing.T) {
-		cps, err := (&DataSource{}).checkpointThis(sdf.StopProcessing())
+		cps, err := (&DataSource{}).checkpointThis(context.Background(), sdf.StopProcessing())
 		if err != nil {
 			t.Fatalf("checkpointThis() = %v, %v", cps, err)
 		}
@@ -899,7 +899,7 @@ func TestCheckpointing(t *testing.T) {
 				},
 			},
 		}
-		cp, err := root.checkpointThis(sdf.ResumeProcessingIn(time.Second * 13))
+		cp, err := root.checkpointThis(context.Background(), sdf.ResumeProcessingIn(time.Second*13))
 		if err != nil {
 			t.Fatalf("checkpointThis() = %v, %v, want nil", cp, err)
 		}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf.go b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
index e22496eae6e..1dd3e35dc4d 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
@@ -90,7 +90,11 @@ func (n *PairWithRestriction) StartBundle(ctx context.Context, id string, data D
 //	  Timestamps
 //	}
 func (n *PairWithRestriction) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error {
-	rest := n.inv.Invoke(elm)
+	rest, err := n.inv.Invoke(ctx, elm)
+	if err != nil {
+		return err
+	}
+
 	output := FullValue{Elm: elm, Elm2: &FullValue{Elm: rest, Elm2: n.iwesInv.Invoke(rest, elm)}, Timestamp: elm.Timestamp, Windows: elm.Windows}
 
 	return n.Out.ProcessElement(ctx, &output, values...)
@@ -195,10 +199,17 @@ func (n *SplitAndSizeRestrictions) ProcessElement(ctx context.Context, elm *Full
 	// the element may not be wrapped in a *FullValue
 	mainElm := convertIfNeeded(elm.Elm, &FullValue{})
 
-	splitRests := n.splitInv.Invoke(mainElm, rest)
+	splitRests, err := n.splitInv.Invoke(ctx, mainElm, rest)
+	if err != nil {
+		return err
+	}
 
 	for _, splitRest := range splitRests {
-		size := n.sizeInv.Invoke(mainElm, splitRest)
+		size, err := n.sizeInv.Invoke(ctx, mainElm, splitRest)
+		if err != nil {
+			return err
+		}
+
 		if size < 0 {
 			err := errors.Errorf("size returned expected to be non-negative but received %v.", size)
 			return errors.WithContextf(err, "%v", n)
@@ -325,13 +336,25 @@ func (n *TruncateSizedRestriction) ProcessElement(ctx context.Context, elm *Full
 		inp = e
 	}
 	rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
-	rt := n.ctInv.Invoke(rest)
-	newRest := n.truncateInv.Invoke(rt, mainElm)
+
+	rt, err := n.ctInv.Invoke(ctx, rest)
+	if err != nil {
+		return err
+	}
+
+	newRest, err := n.truncateInv.Invoke(ctx, rt, mainElm)
+	if err != nil {
+		return err
+	}
 	if newRest == nil {
 		// do not propagate discarded restrictions.
 		return nil
 	}
-	size := n.sizeInv.Invoke(mainElm, newRest)
+
+	size, err := n.sizeInv.Invoke(ctx, mainElm, newRest)
+	if err != nil {
+		return err
+	}
 
 	output := &FullValue{}
 	output.Timestamp = elm.Timestamp
@@ -476,7 +499,7 @@ func (n *ProcessSizedElementsAndRestrictions) StartBundle(ctx context.Context, i
 // and processes each element using the underlying ParDo and adding the
 // restriction tracker to the normal invocation. Sizing information is present
 // but currently ignored. Output is forwarded to the underlying ParDo's outputs.
-func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context, elm *FullValue, values ...ReStream) error {
+func (n *ProcessSizedElementsAndRestrictions) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error {
 	if n.PDo.status != Active {
 		err := errors.Errorf("invalid status %v, want Active", n.PDo.status)
 		return errors.WithContextf(err, "%v", n)
@@ -520,7 +543,12 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context,
 		// If windows don't need to be exploded (i.e. aren't observed), treat
 		// all windows as one as an optimization.
 		rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
-		rt := n.ctInv.Invoke(rest)
+
+		rt, err := n.ctInv.Invoke(ctx, rest)
+		if err != nil {
+			return err
+		}
+
 		mainIn.RTracker = rt
 
 		n.numW = 1 // Even if there's more than one window, treat them as one.
@@ -542,7 +570,12 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context,
 
 		for i := 0; i < n.numW; i++ {
 			rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
-			rt := n.ctInv.Invoke(rest)
+
+			rt, err := n.ctInv.Invoke(ctx, rest)
+			if err != nil {
+				return err
+			}
+
 			key := &mainIn.Key
 			w := elm.Windows[i]
 			wElm := FullValue{Elm: key.Elm, Elm2: key.Elm2, Timestamp: key.Timestamp, Windows: []typex.Window{w}}
@@ -552,7 +585,7 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context,
 			n.elm = elm
 			n.SU <- n
 			// TODO(BEAM-11104): Remove placeholder for ProcessContinuation return.
-			_, err := n.PDo.processSingleWindow(&MainInput{Key: wElm, Values: mainIn.Values, RTracker: rt})
+			_, err = n.PDo.processSingleWindow(&MainInput{Key: wElm, Values: mainIn.Values, RTracker: rt})
 			if err != nil {
 				<-n.SU
 				return n.PDo.fail(err)
@@ -596,13 +629,13 @@ type SplittableUnit interface {
 	//
 	// More than one primary/residual can happen if the split result cannot be
 	// fully represented in just one.
-	Split(fraction float64) (primaries, residuals []*FullValue, err error)
+	Split(ctx context.Context, fraction float64) (primaries, residuals []*FullValue, err error)
 
 	// Checkpoint performs a split at fraction 0.0 of an element that has stopped
 	// processing and has work that needs to be resumed later. This function will
 	// check that the produced primary restriction from the split represents
 	// completed work to avoid data loss and will error if work remains.
-	Checkpoint() (residuals []*FullValue, err error)
+	Checkpoint(ctx context.Context) (residuals []*FullValue, err error)
 
 	// GetProgress returns the fraction of progress the current element has
 	// made in processing. (ex. 0.0 means no progress, and 1.0 means fully
@@ -631,7 +664,7 @@ type SplittableUnit interface {
 // windows need to be taken into account. For implementation details on when
 // each case occurs and the implementation details, see the documentation for
 // the singleWindowSplit and multiWindowSplit methods.
-func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) Split(ctx context.Context, f float64) ([]*FullValue, []*FullValue, error) {
 	// Get the watermark state immediately so that we don't overestimate our current watermark.
 	var pWeState any
 	var rWeState any
@@ -658,7 +691,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []
 	// Split behavior differs depending on whether this is a window-observing
 	// DoFn or not.
 	if len(n.elm.Windows) > 1 {
-		p, r, err := n.multiWindowSplit(f, pWeState, rWeState)
+		p, r, err := n.multiWindowSplit(ctx, f, pWeState, rWeState)
 		if err != nil {
 			return nil, nil, addContext(err)
 		}
@@ -666,7 +699,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []
 	}
 
 	// Not window-observing, or window-observing but only one window.
-	p, r, err := n.singleWindowSplit(f, pWeState, rWeState)
+	p, r, err := n.singleWindowSplit(ctx, f, pWeState, rWeState)
 	if err != nil {
 		return nil, nil, addContext(err)
 	}
@@ -677,11 +710,11 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []
 // later by the runner. This is done iff the underlying Splittable DoFn returns a resuming
 // ProcessContinuation. If the split occurs and the primary restriction is marked as done
 // my the RTracker, the Checkpoint fails as this is a potential data-loss case.
-func (n *ProcessSizedElementsAndRestrictions) Checkpoint() ([]*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) Checkpoint(ctx context.Context) ([]*FullValue, error) {
 	addContext := func(err error) error {
 		return errors.WithContext(err, "Attempting checkpoint in ProcessSizedElementsAndRestrictions")
 	}
-	_, r, err := n.Split(0.0)
+	_, r, err := n.Split(ctx, 0.0)
 
 	if err != nil {
 		return nil, addContext(err)
@@ -699,7 +732,7 @@ func (n *ProcessSizedElementsAndRestrictions) Checkpoint() ([]*FullValue, error)
 // behavior is identical). A single restriction split will occur and all windows
 // present in the unsplit element will be present in both the resulting primary
 // and residual.
-func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, pWeState, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(ctx context.Context, f float64, pWeState, rWeState any) ([]*FullValue, []*FullValue, error) {
 	if n.rt.IsDone() { // Not an error, but not splittable.
 		return []*FullValue{}, []*FullValue{}, nil
 	}
@@ -714,14 +747,14 @@ func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, pWeSt
 
 	var primaryResult []*FullValue
 	if p != nil {
-		pfv, err := n.newSplitResult(p, n.elm.Windows, pWeState)
+		pfv, err := n.newSplitResult(ctx, p, n.elm.Windows, pWeState)
 		if err != nil {
 			return nil, nil, err
 		}
 		primaryResult = append(primaryResult, pfv)
 	}
 
-	rfv, err := n.newSplitResult(r, n.elm.Windows, rWeState)
+	rfv, err := n.newSplitResult(ctx, r, n.elm.Windows, rWeState)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -752,7 +785,7 @@ func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, pWeSt
 //
 // This method also updates the current number of windows (n.numW) so that
 // windows in the residual will no longer be processed.
-func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(ctx context.Context, f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
 	// Get the split point in window range, to see what window it falls in.
 	done, rem := n.rt.GetProgress()
 	cwp := done / (done + rem)                      // Progress in current window.
@@ -765,25 +798,25 @@ func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64, pWeSta
 		if n.rt.IsDone() {
 			// Current RTracker is done so we can't split within the window, so
 			// split at window boundary instead.
-			return n.windowBoundarySplit(n.currW+1, pWeState, rWeState)
+			return n.windowBoundarySplit(ctx, n.currW+1, pWeState, rWeState)
 		}
 
 		// Get the fraction of remaining work in the current window to split at.
 		cwsp := wsp - float64(n.currW) // Split point in current window.
 		rf := (cwsp - cwp) / (1 - cwp) // Fraction of work in RTracker to split at.
 
-		return n.currentWindowSplit(rf, pWeState, rWeState)
+		return n.currentWindowSplit(ctx, rf, pWeState, rWeState)
 	} else {
 		// Split at nearest window boundary to split point.
 		wb := math.Round(wsp)
-		return n.windowBoundarySplit(int(wb), pWeState, rWeState)
+		return n.windowBoundarySplit(ctx, int(wb), pWeState, rWeState)
 	}
 }
 
 // currentWindowSplit performs an appropriate split at the given fraction of
 // remaining work in the current window. Also updates numW to stop after the
 // current window.
-func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(ctx context.Context, f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
 	p, r, err := n.rt.TrySplit(f)
 	if err != nil {
 		return nil, nil, err
@@ -791,18 +824,18 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeS
 	if r == nil {
 		// If r is nil then the split failed/returned an empty residual, but
 		// we can still split at a window boundary.
-		return n.windowBoundarySplit(n.currW+1, pWeState, rWeState)
+		return n.windowBoundarySplit(ctx, n.currW+1, pWeState, rWeState)
 	}
 
 	// Split of currently processing restriction in a single window.
 	ps := make([]*FullValue, 1)
-	newP, err := n.newSplitResult(p, n.elm.Windows[n.currW:n.currW+1], pWeState)
+	newP, err := n.newSplitResult(ctx, p, n.elm.Windows[n.currW:n.currW+1], pWeState)
 	if err != nil {
 		return nil, nil, err
 	}
 	ps[0] = newP
 	rs := make([]*FullValue, 1)
-	newR, err := n.newSplitResult(r, n.elm.Windows[n.currW:n.currW+1], rWeState)
+	newR, err := n.newSplitResult(ctx, r, n.elm.Windows[n.currW:n.currW+1], rWeState)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -810,14 +843,14 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeS
 	// Window boundary split surrounding the split restriction above.
 	full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
 	if 0 < n.currW {
-		newP, err := n.newSplitResult(full, n.elm.Windows[0:n.currW], pWeState)
+		newP, err := n.newSplitResult(ctx, full, n.elm.Windows[0:n.currW], pWeState)
 		if err != nil {
 			return nil, nil, err
 		}
 		ps = append(ps, newP)
 	}
 	if n.currW+1 < n.numW {
-		newR, err := n.newSplitResult(full, n.elm.Windows[n.currW+1:n.numW], rWeState)
+		newR, err := n.newSplitResult(ctx, full, n.elm.Windows[n.currW+1:n.numW], rWeState)
 		if err != nil {
 			return nil, nil, err
 		}
@@ -830,17 +863,17 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeS
 // windowBoundarySplit performs an appropriate split at a window boundary. The
 // split point taken should be the index of the first window in the residual.
 // Also updates numW to stop at the split point.
-func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(ctx context.Context, splitPt int, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
 	// If this is at the boundary of the last window, split is a no-op.
 	if splitPt == n.numW {
 		return []*FullValue{}, []*FullValue{}, nil
 	}
 	full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
-	pFv, err := n.newSplitResult(full, n.elm.Windows[0:splitPt], pWeState)
+	pFv, err := n.newSplitResult(ctx, full, n.elm.Windows[0:splitPt], pWeState)
 	if err != nil {
 		return nil, nil, err
 	}
-	rFv, err := n.newSplitResult(full, n.elm.Windows[splitPt:n.numW], rWeState)
+	rFv, err := n.newSplitResult(ctx, full, n.elm.Windows[splitPt:n.numW], rWeState)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -852,18 +885,27 @@ func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int, p
 // element restriction pair based on the currently processing element, but with
 // a modified restriction and windows. Intended for creating primaries and
 // residuals to return as split results.
-func (n *ProcessSizedElementsAndRestrictions) newSplitResult(rest any, w []typex.Window, weState any) (*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) newSplitResult(ctx context.Context, rest any, w []typex.Window, weState any) (*FullValue, error) {
 	var size float64
+	var err error
 	elm := n.elm.Elm.(*FullValue).Elm
 	if fv, ok := elm.(*FullValue); ok {
-		size = n.sizeInv.Invoke(fv, rest)
+		size, err = n.sizeInv.Invoke(ctx, fv, rest)
+		if err != nil {
+			return nil, errors.WithContextf(err, "%v", n)
+		}
+
 		if size < 0 {
 			err := errors.Errorf("size returned expected to be non-negative but received %v.", size)
 			return nil, errors.WithContextf(err, "%v", n)
 		}
 	} else {
 		fv := &FullValue{Elm: elm}
-		size = n.sizeInv.Invoke(fv, rest)
+		size, err = n.sizeInv.Invoke(ctx, fv, rest)
+		if err != nil {
+			return nil, errors.WithContextf(err, "%v", n)
+		}
+
 		if size < 0 {
 			err := errors.Errorf("size returned expected to be non-negative but received %v.", size)
 			return nil, errors.WithContextf(err, "%v", n)
@@ -973,21 +1015,33 @@ func (n *SdfFallback) StartBundle(ctx context.Context, id string, data DataConte
 // restrictions, and then creating restriction trackers and processing each
 // restriction with the underlying ParDo. This executor skips the sizing step
 // because sizing information is unnecessary for unexpanded SDFs.
-func (n *SdfFallback) ProcessElement(_ context.Context, elm *FullValue, values ...ReStream) error {
+func (n *SdfFallback) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error {
 	if n.PDo.status != Active {
 		err := errors.Errorf("invalid status %v, want Active", n.PDo.status)
 		return errors.WithContextf(err, "%v", n)
 	}
 
-	rest := n.initRestInv.Invoke(elm)
-	splitRests := n.splitInv.Invoke(elm, rest)
+	rest, err := n.initRestInv.Invoke(ctx, elm)
+	if err != nil {
+		return errors.WithContextf(err, "%v", n)
+	}
+
+	splitRests, err := n.splitInv.Invoke(ctx, elm, rest)
+	if err != nil {
+		return errors.WithContextf(err, "%v", n)
+	}
+
 	if len(splitRests) == 0 {
 		err := errors.Errorf("initial splitting returned 0 restrictions.")
 		return errors.WithContextf(err, "%v", n)
 	}
 
 	for _, splitRest := range splitRests {
-		rt := n.trackerInv.Invoke(splitRest)
+		rt, err := n.trackerInv.Invoke(ctx, splitRest)
+		if err != nil {
+			return errors.WithContextf(err, "%v", n)
+		}
+
 		mainIn := &MainInput{
 			Key:      *elm,
 			Values:   values,
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
index 5d58f198a64..2dd894ed08c 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
@@ -16,6 +16,7 @@
 package exec
 
 import (
+	"context"
 	"reflect"
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx"
@@ -24,6 +25,9 @@ import (
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
 )
 
+//go:generate specialize --input=sdf_invokers_arity.tmpl
+//go:generate gofmt -w sdf_invokers_arity.go
+
 // This file contains invokers for SDF methods. These invokers are based off
 // exec.invoker which is used for regular DoFns. Since exec.invoker is
 // specialized for DoFns it cannot be used for SDF methods. Instead, these
@@ -39,9 +43,10 @@ import (
 
 // cirInvoker is an invoker for CreateInitialRestriction.
 type cirInvoker struct {
-	fn   *funcx.Fn
-	args []any // Cache to avoid allocating new slices per-element.
-	call func(elms *FullValue) (rest any)
+	fn     *funcx.Fn
+	args   []any // Cache to avoid allocating new slices per-element.
+	ctxIdx int   // Index of the context.Context param in args, or -1 if absent.
+	call   func() (rest any, err error)
 }
 
 func newCreateInitialRestrictionInvoker(fn *funcx.Fn) (*cirInvoker, error) {
@@ -49,51 +54,34 @@ func newCreateInitialRestrictionInvoker(fn *funcx.Fn) (*cirInvoker, error) {
 		fn:   fn,
 		args: make([]any, len(fn.Param)),
 	}
+
+	var ok bool
+	if n.ctxIdx, ok = fn.Context(); !ok {
+		n.ctxIdx = -1
+	}
+
 	if err := n.initCallFn(); err != nil {
 		return nil, errors.WithContext(err, "sdf CreateInitialRestriction invoker")
 	}
 	return n, nil
 }
 
-func (n *cirInvoker) initCallFn() error {
-	// Expects a signature of the form:
-	// (key?, value) restriction
-	// TODO(BEAM-9643): Link to full documentation.
-	switch fnT := n.fn.Fn.(type) {
-	case reflectx.Func1x1:
-		n.call = func(elms *FullValue) any {
-			return fnT.Call1x1(elms.Elm)
-		}
-	case reflectx.Func2x1:
-		n.call = func(elms *FullValue) any {
-			return fnT.Call2x1(elms.Elm, elms.Elm2)
-		}
-	default:
-		switch len(n.fn.Param) {
-		case 1:
-			n.call = func(elms *FullValue) any {
-				n.args[0] = elms.Elm
-				return n.fn.Fn.Call(n.args)[0]
-			}
-		case 2:
-			n.call = func(elms *FullValue) any {
-				n.args[0] = elms.Elm
-				n.args[1] = elms.Elm2
-				return n.fn.Fn.Call(n.args)[0]
-			}
-		default:
-			return errors.Errorf("CreateInitialRestriction fn %v has unexpected number of parameters: %v",
-				n.fn.Fn.Name(), len(n.fn.Param))
-		}
+// Invoke calls CreateInitialRestriction with ctx (when the fn accepts one) and the
+// element in elms, returning the resulting restriction and any error from the fn.
+func (n *cirInvoker) Invoke(ctx context.Context, elms *FullValue) (rest any, err error) {
+	if n.ctxIdx >= 0 {
+		n.args[n.ctxIdx] = ctx
 	}
 
-	return nil
-}
+	i := n.ctxIdx + 1
+	n.args[i] = elms.Elm
 
-// Invoke calls CreateInitialRestriction with the given FullValue as the element
-// and returns the resulting restriction.
-func (n *cirInvoker) Invoke(elms *FullValue) (rest any) {
-	return n.call(elms)
+	if elms.Elm2 != nil { // NOTE(review): assumes Elm2 is nil iff the fn takes no key param.
+		i++
+		n.args[i] = elms.Elm2
+	}
+
+	return n.call()
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -106,9 +94,10 @@ func (n *cirInvoker) Reset() {
 
 // srInvoker is an invoker for SplitRestriction.
 type srInvoker struct {
-	fn   *funcx.Fn
-	args []any // Cache to avoid allocating new slices per-element.
-	call func(elms *FullValue, rest any) (splits any)
+	fn     *funcx.Fn
+	args   []any // Cache to avoid allocating new slices per-element.
+	ctxIdx int   // Index of the context.Context param in args, or -1 if absent.
+	call   func() (splits any, err error)
 }
 
 func newSplitRestrictionInvoker(fn *funcx.Fn) (*srInvoker, error) {
@@ -116,52 +105,40 @@ func newSplitRestrictionInvoker(fn *funcx.Fn) (*srInvoker, error) {
 		fn:   fn,
 		args: make([]any, len(fn.Param)),
 	}
+
+	var ok bool
+	if n.ctxIdx, ok = fn.Context(); !ok {
+		n.ctxIdx = -1
+	}
+
 	if err := n.initCallFn(); err != nil {
 		return nil, errors.WithContext(err, "sdf SplitRestriction invoker")
 	}
 	return n, nil
 }
 
-func (n *srInvoker) initCallFn() error {
-	// Expects a signature of the form:
-	// (key?, value, restriction) []restriction
-	// TODO(BEAM-9643): Link to full documentation.
-	switch fnT := n.fn.Fn.(type) {
-	case reflectx.Func2x1:
-		n.call = func(elms *FullValue, rest any) any {
-			return fnT.Call2x1(elms.Elm, rest)
-		}
-	case reflectx.Func3x1:
-		n.call = func(elms *FullValue, rest any) any {
-			return fnT.Call3x1(elms.Elm, elms.Elm2, rest)
-		}
-	default:
-		switch len(n.fn.Param) {
-		case 2:
-			n.call = func(elms *FullValue, rest any) any {
-				n.args[0] = elms.Elm
-				n.args[1] = rest
-				return n.fn.Fn.Call(n.args)[0]
-			}
-		case 3:
-			n.call = func(elms *FullValue, rest any) any {
-				n.args[0] = elms.Elm
-				n.args[1] = elms.Elm2
-				n.args[2] = rest
-				return n.fn.Fn.Call(n.args)[0]
-			}
-		default:
-			return errors.Errorf("SplitRestriction fn %v has unexpected number of parameters: %v",
-				n.fn.Fn.Name(), len(n.fn.Param))
-		}
-	}
-	return nil
-}
-
 // Invoke calls SplitRestriction given a FullValue containing an element and
 // the associated restriction, and returns a slice of split restrictions.
-func (n *srInvoker) Invoke(elms *FullValue, rest any) (splits []any) {
-	ret := n.call(elms, rest)
+func (n *srInvoker) Invoke(ctx context.Context, elms *FullValue, rest any) (splits []any, err error) {
+	if n.ctxIdx >= 0 {
+		n.args[n.ctxIdx] = ctx
+	}
+
+	i := n.ctxIdx + 1
+	n.args[i] = elms.Elm
+
+	if elms.Elm2 != nil { // NOTE(review): assumes Elm2 is nil iff the fn takes no key param.
+		i++
+		n.args[i] = elms.Elm2
+	}
+
+	i++
+	n.args[i] = rest
+
+	ret, err := n.call()
+	if err != nil {
+		return nil, err
+	}
 
 	// Return value is an any, but we need to convert it to a []any.
 	val := reflect.ValueOf(ret)
@@ -169,7 +146,7 @@ func (n *srInvoker) Invoke(elms *FullValue, rest any) (splits []any) {
 	for i := 0; i < val.Len(); i++ {
 		s = append(s, val.Index(i).Interface())
 	}
-	return s
+	return s, nil
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -182,9 +159,10 @@ func (n *srInvoker) Reset() {
 
 // rsInvoker is an invoker for RestrictionSize.
 type rsInvoker struct {
-	fn   *funcx.Fn
-	args []any // Cache to avoid allocating new slices per-element.
-	call func(elms *FullValue, rest any) (size float64)
+	fn     *funcx.Fn
+	args   []any // Cache to avoid allocating new slices per-element.
+	ctxIdx int   // Index of the context.Context param in args, or -1 if absent.
+	call   func() (size float64, err error)
 }
 
 func newRestrictionSizeInvoker(fn *funcx.Fn) (*rsInvoker, error) {
@@ -192,52 +170,37 @@ func newRestrictionSizeInvoker(fn *funcx.Fn) (*rsInvoker, error) {
 		fn:   fn,
 		args: make([]any, len(fn.Param)),
 	}
+
+	var ok bool
+	if n.ctxIdx, ok = fn.Context(); !ok {
+		n.ctxIdx = -1
+	}
+
 	if err := n.initCallFn(); err != nil {
 		return nil, errors.WithContext(err, "sdf RestrictionSize invoker")
 	}
 	return n, nil
 }
 
-func (n *rsInvoker) initCallFn() error {
-	// Expects a signature of the form:
-	// (key?, value, restriction) float64
-	// TODO(BEAM-9643): Link to full documentation.
-	switch fnT := n.fn.Fn.(type) {
-	case reflectx.Func2x1:
-		n.call = func(elms *FullValue, rest any) float64 {
-			return fnT.Call2x1(elms.Elm, rest).(float64)
-		}
-	case reflectx.Func3x1:
-		n.call = func(elms *FullValue, rest any) float64 {
-			return fnT.Call3x1(elms.Elm, elms.Elm2, rest).(float64)
-		}
-	default:
-		switch len(n.fn.Param) {
-		case 2:
-			n.call = func(elms *FullValue, rest any) float64 {
-				n.args[0] = elms.Elm
-				n.args[1] = rest
-				return n.fn.Fn.Call(n.args)[0].(float64)
-			}
-		case 3:
-			n.call = func(elms *FullValue, rest any) float64 {
-				n.args[0] = elms.Elm
-				n.args[1] = elms.Elm2
-				n.args[2] = rest
-				return n.fn.Fn.Call(n.args)[0].(float64)
-			}
-		default:
-			return errors.Errorf("RestrictionSize fn %v has unexpected number of parameters: %v",
-				n.fn.Fn.Name(), len(n.fn.Param))
-		}
-	}
-	return nil
-}
-
 // Invoke calls RestrictionSize given a FullValue containing an element and
 // the associated restriction, and returns a size.
-func (n *rsInvoker) Invoke(elms *FullValue, rest any) (size float64) {
-	return n.call(elms, rest)
+func (n *rsInvoker) Invoke(ctx context.Context, elms *FullValue, rest any) (size float64, err error) {
+	if n.ctxIdx >= 0 {
+		n.args[n.ctxIdx] = ctx
+	}
+
+	i := n.ctxIdx + 1
+	n.args[i] = elms.Elm
+
+	if elms.Elm2 != nil { // NOTE(review): assumes Elm2 is nil iff the fn takes no key param.
+		i++
+		n.args[i] = elms.Elm2
+	}
+
+	i++
+	n.args[i] = rest
+
+	return n.call()
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -250,9 +213,10 @@ func (n *rsInvoker) Reset() {
 
 // ctInvoker is an invoker for CreateTracker.
 type ctInvoker struct {
-	fn   *funcx.Fn
-	args []any // Cache to avoid allocating new slices per-element.
-	call func(rest any) sdf.RTracker
+	fn     *funcx.Fn
+	args   []any // Cache to avoid allocating new slices per-element.
+	ctxIdx int   // Index of the context.Context param in args, or -1 if absent.
+	call   func() (rt sdf.RTracker, err error)
 }
 
 func newCreateTrackerInvoker(fn *funcx.Fn) (*ctInvoker, error) {
@@ -260,37 +224,27 @@ func newCreateTrackerInvoker(fn *funcx.Fn) (*ctInvoker, error) {
 		fn:   fn,
 		args: make([]any, len(fn.Param)),
 	}
+
+	var ok bool
+	if n.ctxIdx, ok = fn.Context(); !ok {
+		n.ctxIdx = -1
+	}
+
 	if err := n.initCallFn(); err != nil {
 		return nil, errors.WithContext(err, "sdf CreateTracker invoker")
 	}
 	return n, nil
 }
 
-func (n *ctInvoker) initCallFn() error {
-	// Expects a signature of the form:
-	// (restriction) sdf.RTracker
-	// TODO(BEAM-9643): Link to full documentation.
-	switch fnT := n.fn.Fn.(type) {
-	case reflectx.Func1x1:
-		n.call = func(rest any) sdf.RTracker {
-			return fnT.Call1x1(rest).(sdf.RTracker)
-		}
-	default:
-		if len(n.fn.Param) != 1 {
-			return errors.Errorf("CreateTracker fn %v has unexpected number of parameters: %v",
-				n.fn.Fn.Name(), len(n.fn.Param))
-		}
-		n.call = func(rest any) sdf.RTracker {
-			n.args[0] = rest
-			return n.fn.Fn.Call(n.args)[0].(sdf.RTracker)
-		}
+// Invoke calls CreateTracker with ctx (when accepted) and rest, returning the tracker and any error.
+func (n *ctInvoker) Invoke(ctx context.Context, rest any) (sdf.RTracker, error) {
+	if n.ctxIdx >= 0 {
+		n.args[n.ctxIdx] = ctx
 	}
-	return nil
-}
 
-// Invoke calls CreateTracker given a restriction and returns an sdf.RTracker.
-func (n *ctInvoker) Invoke(rest any) sdf.RTracker {
-	return n.call(rest)
+	n.args[n.ctxIdx+1] = rest
+
+	return n.call()
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -303,9 +257,10 @@ func (n *ctInvoker) Reset() {
 
 // trInvoker is an invoker for TruncateRestriction.
 type trInvoker struct {
-	fn   *funcx.Fn
-	args []any
-	call func(rest any, elms *FullValue) (pair any)
+	fn     *funcx.Fn
+	args   []any
+	ctxIdx int // Index of the context.Context param in args, or -1 if absent.
+	call   func() (rest any, err error)
 }
 
 func defaultTruncateRestriction(restTracker any) (newRest any) {
@@ -320,6 +275,12 @@ func newTruncateRestrictionInvoker(fn *funcx.Fn) (*trInvoker, error) {
 		fn:   fn,
 		args: make([]any, len(fn.Param)),
 	}
+
+	var ok bool
+	if n.ctxIdx, ok = fn.Context(); !ok {
+		n.ctxIdx = -1
+	}
+
 	if err := n.initCallFn(); err != nil {
 		return nil, errors.WithContext(err, "sdf TruncateRestriction invoker")
 	}
@@ -327,53 +288,39 @@ func newTruncateRestrictionInvoker(fn *funcx.Fn) (*trInvoker, error) {
 }
 
 func newDefaultTruncateRestrictionInvoker() (*trInvoker, error) {
-	n := &trInvoker{}
-	n.call = func(rest any, elms *FullValue) any {
-		return defaultTruncateRestriction(rest)
+	n := &trInvoker{
+		args: make([]any, 1),
 	}
-	return n, nil
-}
-
-func (n *trInvoker) initCallFn() error {
-	// Expects a signature of the form:
-	// (key?, value, restriction) []restriction
-	// TODO(BEAM-9643): Link to full documentation.
-	switch fnT := n.fn.Fn.(type) {
-	case reflectx.Func2x1:
-		n.call = func(rest any, elms *FullValue) any {
-			return fnT.Call2x1(rest, elms.Elm)
-		}
-	case reflectx.Func3x1:
-		n.call = func(rest any, elms *FullValue) any {
-			return fnT.Call3x1(rest, elms.Elm, elms.Elm2)
-		}
-	default:
-		switch len(n.fn.Param) {
-		case 2:
-			n.call = func(rest any, elms *FullValue) any {
-				n.args[0] = rest
-				n.args[1] = elms.Elm
-				return n.fn.Fn.Call(n.args)[0]
-			}
-		case 3:
-			n.call = func(rest any, elms *FullValue) any {
-				n.args[0] = rest
-				n.args[1] = elms.Elm
-				n.args[2] = elms.Elm2
-				return n.fn.Fn.Call(n.args)[0]
-			}
-		default:
-			return errors.Errorf("TruncateRestriction fn %v has unexpected number of parameters: %v",
-				n.fn.Fn.Name(), len(n.fn.Param))
-		}
+	n.call = func() (any, error) {
+		return defaultTruncateRestriction(n.args[0]), nil
 	}
-	return nil
+	return n, nil
 }
 
 // Invoke calls TruncateRestriction given a FullValue containing an element and
 // the associated restriction tracker, and returns a truncated restriction.
-func (n *trInvoker) Invoke(rt any, elms *FullValue) (rest any) {
-	return n.call(rt, elms)
+func (n *trInvoker) Invoke(ctx context.Context, rt any, elms *FullValue) (rest any, err error) {
+	if n.fn == nil { // Default invoker (no user fn); only rt is used.
+		n.args[0] = rt
+		return n.call()
+	}
+
+	if n.ctxIdx >= 0 {
+		n.args[n.ctxIdx] = ctx
+	}
+
+	i := n.ctxIdx + 1
+	n.args[i] = rt
+
+	i++
+	n.args[i] = elms.Elm
+
+	if elms.Elm2 != nil { // NOTE(review): assumes Elm2 is nil iff the fn takes no key param.
+		i++
+		n.args[i] = elms.Elm2
+	}
+
+	return n.call()
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -589,3 +536,11 @@ func (n *wesInvoker) Reset() {
 		n.args[i] = nil
 	}
 }
+
+func asError(val any) error {
+	if val != nil {
+		return val.(error) // Panics if val does not implement error.
+	}
+
+	return nil
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go
new file mode 100644
index 00000000000..cdefa711603
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go
@@ -0,0 +1,336 @@
+// File generated by specialize. Do not edit.
+
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated from sdf_invokers_arity.tmpl. DO NOT EDIT.
+
+package exec
+
+import (
+	"fmt"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+)
+
+func (n *cirInvoker) initCallFn() error {
+	// Expects a signature of the form:
+	// (context.Context?, key?, value) (restriction, error?)
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+
+	case reflectx.Func1x1:
+		n.call = func() (rest any, err error) {
+			r0 := fnT.Call1x1(n.args[0])
+			return r0, nil
+		}
+
+	case reflectx.Func2x1:
+		n.call = func() (rest any, err error) {
+			r0 := fnT.Call2x1(n.args[0], n.args[1])
+			return r0, nil
+		}
+
+	case reflectx.Func3x1:
+		n.call = func() (rest any, err error) {
+			r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+			return r0, nil
+		}
+
+	case reflectx.Func1x2:
+		n.call = func() (rest any, err error) {
+			r0, r1 := fnT.Call1x2(n.args[0])
+			return r0, asError(r1)
+		}
+
+	case reflectx.Func2x2:
+		n.call = func() (rest any, err error) {
+			r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+			return r0, asError(r1)
+		}
+
+	case reflectx.Func3x2:
+		n.call = func() (rest any, err error) {
+			r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+			return r0, asError(r1)
+		}
+
+	default:
+		if len(n.fn.Param) < 1 || len(n.fn.Param) > 3 {
+			return errors.Errorf("CreateInitialRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+		}
+
+		n.call = func() (rest any, err error) {
+			ret := n.fn.Fn.Call(n.args)
+
+			switch len(ret) {
+			case 1:
+				return ret[0], nil
+			case 2:
+				return ret[0], asError(ret[1])
+			}
+
+			panic(fmt.Sprintf("CreateInitialRestriction has unexpected number of return values: %v", len(ret)))
+		}
+	}
+
+	return nil
+}
+
+func (n *srInvoker) initCallFn() error {
+	// Expects a signature of the form:
+	// (context.Context?, key?, value, restriction) ([]restriction, error?)
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+
+	case reflectx.Func2x1:
+		n.call = func() (splits any, err error) {
+			r0 := fnT.Call2x1(n.args[0], n.args[1])
+			return r0, nil
+		}
+
+	case reflectx.Func3x1:
+		n.call = func() (splits any, err error) {
+			r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+			return r0, nil
+		}
+
+	case reflectx.Func4x1:
+		n.call = func() (splits any, err error) {
+			r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], n.args[3])
+			return r0, nil
+		}
+
+	case reflectx.Func2x2:
+		n.call = func() (splits any, err error) {
+			r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+			return r0, asError(r1)
+		}
+
+	case reflectx.Func3x2:
+		n.call = func() (splits any, err error) {
+			r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+			return r0, asError(r1)
+		}
+
+	case reflectx.Func4x2:
+		n.call = func() (splits any, err error) {
+			r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], n.args[3])
+			return r0, asError(r1)
+		}
+
+	default:
+		if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+			return errors.Errorf("SplitRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+		}
+
+		n.call = func() (splits any, err error) {
+			ret := n.fn.Fn.Call(n.args)
+
+			switch len(ret) {
+			case 1:
+				return ret[0], nil
+			case 2:
+				return ret[0], asError(ret[1])
+			}
+
+			panic(fmt.Sprintf("SplitRestriction has unexpected number of return values: %v", len(ret)))
+		}
+	}
+
+	return nil
+}
+
+func (n *rsInvoker) initCallFn() error {
+	// Expects a signature of the form:
+	// (context.Context?, key?, value, restriction) (float64, error?)
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+
+	case reflectx.Func2x1:
+		n.call = func() (size float64, err error) {
+			r0 := fnT.Call2x1(n.args[0], n.args[1])
+			return r0.(float64), nil
+		}
+
+	case reflectx.Func3x1:
+		n.call = func() (size float64, err error) {
+			r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+			return r0.(float64), nil
+		}
+
+	case reflectx.Func4x1:
+		n.call = func() (size float64, err error) {
+			r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], n.args[3])
+			return r0.(float64), nil
+		}
+
+	case reflectx.Func2x2:
+		n.call = func() (size float64, err error) {
+			r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+			return r0.(float64), asError(r1)
+		}
+
+	case reflectx.Func3x2:
+		n.call = func() (size float64, err error) {
+			r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+			return r0.(float64), asError(r1)
+		}
+
+	case reflectx.Func4x2:
+		n.call = func() (size float64, err error) {
+			r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], n.args[3])
+			return r0.(float64), asError(r1)
+		}
+
+	default:
+		if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+			return errors.Errorf("RestrictionSize has unexpected number of parameters: %v", len(n.fn.Param))
+		}
+
+		n.call = func() (size float64, err error) {
+			ret := n.fn.Fn.Call(n.args)
+
+			switch len(ret) {
+			case 1:
+				return ret[0].(float64), nil
+			case 2:
+				return ret[0].(float64), asError(ret[1])
+			}
+
+			panic(fmt.Sprintf("RestrictionSize has unexpected number of return values: %v", len(ret)))
+		}
+	}
+
+	return nil
+}
+
+func (n *ctInvoker) initCallFn() error {
+	// Expects a signature of the form:
+	// (context.Context?, restriction) (sdf.RTracker, error?)
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+
+	case reflectx.Func1x1:
+		n.call = func() (rt sdf.RTracker, err error) {
+			r0 := fnT.Call1x1(n.args[0])
+			return r0.(sdf.RTracker), nil
+		}
+
+	case reflectx.Func2x1:
+		n.call = func() (rt sdf.RTracker, err error) {
+			r0 := fnT.Call2x1(n.args[0], n.args[1])
+			return r0.(sdf.RTracker), nil
+		}
+
+	case reflectx.Func1x2:
+		n.call = func() (rt sdf.RTracker, err error) {
+			r0, r1 := fnT.Call1x2(n.args[0])
+			return r0.(sdf.RTracker), asError(r1)
+		}
+
+	case reflectx.Func2x2:
+		n.call = func() (rt sdf.RTracker, err error) {
+			r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+			return r0.(sdf.RTracker), asError(r1)
+		}
+
+	default:
+		if len(n.fn.Param) < 1 || len(n.fn.Param) > 2 {
+			return errors.Errorf("CreateTracker has unexpected number of parameters: %v", len(n.fn.Param))
+		}
+
+		n.call = func() (rt sdf.RTracker, err error) {
+			ret := n.fn.Fn.Call(n.args)
+
+			switch len(ret) {
+			case 1:
+				return ret[0].(sdf.RTracker), nil
+			case 2:
+				return ret[0].(sdf.RTracker), asError(ret[1])
+			}
+
+			panic(fmt.Sprintf("CreateTracker has unexpected number of return values: %v", len(ret)))
+		}
+	}
+
+	return nil
+}
+
+func (n *trInvoker) initCallFn() error {
+	// Expects a signature of the form:
+	// (context.Context?, sdf.RTracker, key?, value) (restriction, error?)
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+
+	case reflectx.Func2x1:
+		n.call = func() (rest any, err error) {
+			r0 := fnT.Call2x1(n.args[0], n.args[1])
+			return r0, nil
+		}
+
+	case reflectx.Func3x1:
+		n.call = func() (rest any, err error) {
+			r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+			return r0, nil
+		}
+
+	case reflectx.Func4x1:
+		n.call = func() (rest any, err error) {
+			r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], n.args[3])
+			return r0, nil
+		}
+
+	case reflectx.Func2x2:
+		n.call = func() (rest any, err error) {
+			r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+			return r0, asError(r1)
+		}
+
+	case reflectx.Func3x2:
+		n.call = func() (rest any, err error) {
+			r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+			return r0, asError(r1)
+		}
+
+	case reflectx.Func4x2:
+		n.call = func() (rest any, err error) {
+			r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], n.args[3])
+			return r0, asError(r1)
+		}
+
+	default:
+		if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+			return errors.Errorf("TruncateRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+		}
+
+		n.call = func() (rest any, err error) {
+			ret := n.fn.Fn.Call(n.args)
+
+			switch len(ret) {
+			case 1:
+				return ret[0], nil
+			case 2:
+				return ret[0], asError(ret[1])
+			}
+
+			panic(fmt.Sprintf("TruncateRestriction has unexpected number of return values: %v", len(ret)))
+		}
+	}
+
+	return nil
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl
new file mode 100644
index 00000000000..7df994be0c7
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl
@@ -0,0 +1,246 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated from sdf_invokers_arity.tmpl. DO NOT EDIT.
+
+package exec
+
+import (
+	"fmt"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+)
+
+// initCallFn binds n.call to a closure that invokes the user's
+// CreateInitialRestriction method with the arguments held in n.args.
+//
+// Function types with 1..3 parameters and 1..2 results get an arity-specific
+// reflectx.FuncNxM case; anything else goes through the generic
+// n.fn.Fn.Call in the default branch. When the method has a second result it
+// is converted with asError and returned as the error.
+//
+// It returns an error only when the default branch sees a parameter count
+// outside 1..3; the closure it installs there panics if the call yields other
+// than 1 or 2 results.
+func (n *cirInvoker) initCallFn() error {
+	// Expects a signature of the form:
+	// (context.Context?, key?, value) (restriction, error?)
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}
+{{range $in := upto 4}}
+    {{if gt $out 0}}
+    {{if gt $in 0}}
+	case reflectx.Func{{$in}}x{{$out}}:
+		n.call = func() (rest any, err error) {
+			{{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+			{{- if eq $out 1}}
+			return r0, nil
+			{{- else}}
+			return r0, asError(r1)
+			{{- end}}
+		}
+	{{end}}
+	{{end}}
+{{end}}
+{{end}}
+	default:
+		// Fallback for function types without an arity-specific case above.
+		if len(n.fn.Param) < 1 || len(n.fn.Param) > 3 {
+			return errors.Errorf("CreateInitialRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+		}
+
+		n.call = func() (rest any, err error) {
+			ret := n.fn.Fn.Call(n.args)
+
+			switch len(ret) {
+			case 1:
+				return ret[0], nil
+			case 2:
+				return ret[0], asError(ret[1])
+			}
+
+			panic(fmt.Sprintf("CreateInitialRestriction has unexpected number of return values: %v", len(ret)))
+		}
+	}
+
+	return nil
+}
+
+// initCallFn binds n.call to a closure that invokes the user's
+// SplitRestriction method with the arguments held in n.args and returns the
+// method's first result (the splits) as an untyped value.
+//
+// Function types with 2..4 parameters and 1..2 results get an arity-specific
+// reflectx.FuncNxM case; anything else goes through the generic
+// n.fn.Fn.Call in the default branch. When the method has a second result it
+// is converted with asError and returned as the error.
+//
+// It returns an error only when the default branch sees a parameter count
+// outside 2..4; the closure it installs there panics if the call yields other
+// than 1 or 2 results.
+func (n *srInvoker) initCallFn() error {
+	// Expects a signature of the form:
+	// (context.Context?, key?, value, restriction) ([]restriction, error?)
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}
+{{range $in := upto 5}}
+    {{if gt $out 0}}
+    {{if gt $in 1}}
+	case reflectx.Func{{$in}}x{{$out}}:
+		n.call = func() (splits any, err error) {
+			{{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+			{{- if eq $out 1}}
+			return r0, nil
+			{{- else}}
+			return r0, asError(r1)
+			{{- end}}
+		}
+	{{end}}
+	{{end}}
+{{end}}
+{{end}}
+	default:
+		// Fallback for function types without an arity-specific case above.
+		if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+			return errors.Errorf("SplitRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+		}
+
+		n.call = func() (splits any, err error) {
+			ret := n.fn.Fn.Call(n.args)
+
+			switch len(ret) {
+			case 1:
+				return ret[0], nil
+			case 2:
+				return ret[0], asError(ret[1])
+			}
+
+			panic(fmt.Sprintf("SplitRestriction has unexpected number of return values: %v", len(ret)))
+		}
+	}
+
+	return nil
+}
+
+// initCallFn binds n.call to a closure that invokes the user's
+// RestrictionSize method with the arguments held in n.args. The first result
+// is type-asserted to float64; the assertion panics if it is anything else.
+//
+// Function types with 2..4 parameters and 1..2 results get an arity-specific
+// reflectx.FuncNxM case; anything else goes through the generic
+// n.fn.Fn.Call in the default branch. When the method has a second result it
+// is converted with asError and returned as the error.
+//
+// It returns an error only when the default branch sees a parameter count
+// outside 2..4; the closure it installs there panics if the call yields other
+// than 1 or 2 results.
+func (n *rsInvoker) initCallFn() error {
+	// Expects a signature of the form:
+	// (context.Context?, key?, value, restriction) (float64, error?)
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}
+{{range $in := upto 5}}
+    {{if gt $out 0}}
+    {{if gt $in 1}}
+	case reflectx.Func{{$in}}x{{$out}}:
+		n.call = func() (size float64, err error) {
+			{{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+			{{- if eq $out 1}}
+			return r0.(float64), nil
+			{{- else}}
+			return r0.(float64), asError(r1)
+			{{- end}}
+		}
+	{{end}}
+	{{end}}
+{{end}}
+{{end}}
+	default:
+		// Fallback for function types without an arity-specific case above.
+		if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+			return errors.Errorf("RestrictionSize has unexpected number of parameters: %v", len(n.fn.Param))
+		}
+
+		n.call = func() (size float64, err error) {
+			ret := n.fn.Fn.Call(n.args)
+
+			switch len(ret) {
+			case 1:
+				return ret[0].(float64), nil
+			case 2:
+				return ret[0].(float64), asError(ret[1])
+			}
+
+			panic(fmt.Sprintf("RestrictionSize has unexpected number of return values: %v", len(ret)))
+		}
+	}
+
+	return nil
+}
+
+// initCallFn binds n.call to a closure that invokes the user's CreateTracker
+// method with the arguments held in n.args and returns the resulting
+// sdf.RTracker.
+//
+// Function types with 1..2 parameters and 1..2 results get an arity-specific
+// reflectx.FuncNxM case; anything else goes through the generic
+// n.fn.Fn.Call in the default branch. When the method has a second result it
+// is converted with asError and returned as the error.
+//
+// The error is checked before the tracker is type-asserted. If CreateTracker
+// returns (nil, err) with an interface-typed first result, that result
+// arrives here as a nil `any`, and an unconditional r0.(sdf.RTracker) would
+// panic and hide err. On error the closure therefore returns (nil, err).
+//
+// It returns an error only when the default branch sees a parameter count
+// outside 1..2; the closure it installs there panics if the call yields other
+// than 1 or 2 results.
+func (n *ctInvoker) initCallFn() error {
+	// Expects a signature of the form:
+	// (context.Context?, restriction) (sdf.RTracker, error?)
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}
+{{range $in := upto 3}}
+    {{if gt $out 0}}
+    {{if gt $in 0}}
+	case reflectx.Func{{$in}}x{{$out}}:
+		n.call = func() (rt sdf.RTracker, err error) {
+			{{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+			{{- if eq $out 1}}
+			return r0.(sdf.RTracker), nil
+			{{- else}}
+			if err = asError(r1); err != nil {
+				return nil, err
+			}
+			return r0.(sdf.RTracker), nil
+			{{- end}}
+		}
+	{{end}}
+	{{end}}
+{{end}}
+{{end}}
+	default:
+		// Fallback for function types without an arity-specific case above.
+		if len(n.fn.Param) < 1 || len(n.fn.Param) > 2 {
+			return errors.Errorf("CreateTracker has unexpected number of parameters: %v", len(n.fn.Param))
+		}
+
+		n.call = func() (rt sdf.RTracker, err error) {
+			ret := n.fn.Fn.Call(n.args)
+
+			switch len(ret) {
+			case 1:
+				return ret[0].(sdf.RTracker), nil
+			case 2:
+				if err = asError(ret[1]); err != nil {
+					return nil, err
+				}
+				return ret[0].(sdf.RTracker), nil
+			}
+
+			panic(fmt.Sprintf("CreateTracker has unexpected number of return values: %v", len(ret)))
+		}
+	}
+
+	return nil
+}
+
+// initCallFn binds n.call to a closure that invokes the user's
+// TruncateRestriction method with the arguments held in n.args.
+//
+// Function types with 2..4 parameters and 1..2 results get an arity-specific
+// reflectx.FuncNxM case; anything else goes through the generic
+// n.fn.Fn.Call in the default branch. When the method has a second result it
+// is converted with asError and returned as the error.
+//
+// It returns an error only when the default branch sees a parameter count
+// outside 2..4; the closure it installs there panics if the call yields other
+// than 1 or 2 results.
+func (n *trInvoker) initCallFn() error {
+	// Expects a signature of the form:
+	// (context.Context?, sdf.RTracker, key?, value) (restriction, error?)
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}
+{{range $in := upto 5}}
+    {{if gt $out 0}}
+    {{if gt $in 1}}
+	case reflectx.Func{{$in}}x{{$out}}:
+		n.call = func() (rest any, err error) {
+			{{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+			{{- if eq $out 1}}
+			return r0, nil
+			{{- else}}
+			return r0, asError(r1)
+			{{- end}}
+		}
+	{{end}}
+	{{end}}
+{{end}}
+{{end}}
+	default:
+		// Fallback for function types without an arity-specific case above.
+		if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+			return errors.Errorf("TruncateRestriction has unexpected number of parameters: %v", len(n.fn.Param))
+		}
+
+		n.call = func() (rest any, err error) {
+			ret := n.fn.Fn.Call(n.args)
+
+			switch len(ret) {
+			case 1:
+				return ret[0], nil
+			case 2:
+				return ret[0], asError(ret[1])
+			}
+
+			panic(fmt.Sprintf("TruncateRestriction has unexpected number of return values: %v", len(ret)))
+		}
+	}
+
+	return nil
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
index e308071aaf9..edef16e51d7 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
@@ -16,6 +16,8 @@
 package exec
 
 import (
+	"context"
+	"errors"
 	"testing"
 	"time"
 
@@ -48,13 +50,59 @@ func TestInvokes(t *testing.T) {
 	}
 	statefulWeFn := (*graph.SplittableDoFn)(dfn)
 
+	initialRestErrDfn, err := graph.NewDoFn(
+		&VetCreateInitialRestrictionErrSdf{},
+		graph.NumMainInputs(graph.MainSingle),
+	)
+	if err != nil {
+		t.Fatalf("invalid function: %v", err)
+	}
+	initialRestErrSdf := (*graph.SplittableDoFn)(initialRestErrDfn)
+
+	splitRestErrDfn, err := graph.NewDoFn(
+		&VetSplitRestrictionErrSdf{},
+		graph.NumMainInputs(graph.MainSingle),
+	)
+	if err != nil {
+		t.Fatalf("invalid function: %v", err)
+	}
+	splitRestErrSdf := (*graph.SplittableDoFn)(splitRestErrDfn)
+
+	restSizeErrDfn, err := graph.NewDoFn(
+		&VetRestrictionSizeErrSdf{},
+		graph.NumMainInputs(graph.MainSingle),
+	)
+	if err != nil {
+		t.Fatalf("invalid function: %v", err)
+	}
+	restSizeErrSdf := (*graph.SplittableDoFn)(restSizeErrDfn)
+
+	trackerErrDfn, err := graph.NewDoFn(
+		&VetCreateTrackerErrSdf{},
+		graph.NumMainInputs(graph.MainSingle),
+	)
+	if err != nil {
+		t.Fatalf("invalid function: %v", err)
+	}
+	trackerErrSdf := (*graph.SplittableDoFn)(trackerErrDfn)
+
+	truncateRestErrDfn, err := graph.NewDoFn(
+		&VetTruncateRestrictionErrSdf{},
+		graph.NumMainInputs(graph.MainSingle),
+	)
+	if err != nil {
+		t.Fatalf("invalid function: %v", err)
+	}
+	truncateRestErrSdf := (*graph.SplittableDoFn)(truncateRestErrDfn)
+
 	// Tests.
 	t.Run("CreateInitialRestriction Invoker (cirInvoker)", func(t *testing.T) {
 		tests := []struct {
-			name string
-			sdf  *graph.SplittableDoFn
-			elms *FullValue
-			want *VetRestriction
+			name    string
+			sdf     *graph.SplittableDoFn
+			elms    *FullValue
+			want    *VetRestriction
+			wantErr bool
 		}{
 			{
 				name: "SingleElem",
@@ -68,6 +116,12 @@ func TestInvokes(t *testing.T) {
 				elms: &FullValue{Elm: 1, Elm2: 2},
 				want: &VetRestriction{ID: "KvSdf", CreateRest: true, Key: 1, Val: 2},
 			},
+			{
+				name:    "Error",
+				sdf:     initialRestErrSdf,
+				elms:    &FullValue{Elm: 1},
+				wantErr: true,
+			},
 		}
 		for _, test := range tests {
 			test := test
@@ -77,7 +131,15 @@ func TestInvokes(t *testing.T) {
 				if err != nil {
 					t.Fatalf("newCreateInitialRestrictionInvoker failed: %v", err)
 				}
-				got := invoker.Invoke(test.elms)
+
+				got, err := invoker.Invoke(context.Background(), test.elms)
+				if (err != nil) != test.wantErr {
+					t.Fatalf("Invoke(%v) error = %v, wantErr %v", test.elms, err, test.wantErr)
+				}
+				if test.wantErr {
+					return
+				}
+
 				if !cmp.Equal(got, test.want) {
 					t.Errorf("Invoke(%v) has incorrect output: got: %v, want: %v",
 						test.elms, got, test.want)
@@ -94,11 +156,12 @@ func TestInvokes(t *testing.T) {
 
 	t.Run("SplitRestriction Invoker (srInvoker)", func(t *testing.T) {
 		tests := []struct {
-			name string
-			sdf  *graph.SplittableDoFn
-			elms *FullValue
-			rest *VetRestriction
-			want []any
+			name    string
+			sdf     *graph.SplittableDoFn
+			elms    *FullValue
+			rest    *VetRestriction
+			want    []any
+			wantErr bool
 		}{
 			{
 				name: "SingleElem",
@@ -119,6 +182,13 @@ func TestInvokes(t *testing.T) {
 					&VetRestriction{ID: "KvSdf.2", SplitRest: true, Key: 1, Val: 2},
 				},
 			},
+			{
+				name:    "Error",
+				sdf:     splitRestErrSdf,
+				elms:    &FullValue{Elm: 1},
+				rest:    &VetRestriction{ID: "Sdf"},
+				wantErr: true,
+			},
 		}
 		for _, test := range tests {
 			test := test
@@ -128,8 +198,16 @@ func TestInvokes(t *testing.T) {
 				if err != nil {
 					t.Fatalf("newSplitRestrictionInvoker failed: %v", err)
 				}
+
 				rest := *test.rest // Create a copy because our test SDF edits the restriction.
-				got := invoker.Invoke(test.elms, &rest)
+				got, err := invoker.Invoke(context.Background(), test.elms, &rest)
+				if (err != nil) != test.wantErr {
+					t.Fatalf("Invoke(%v, %v) error = %v, wantErr %v", test.elms, test.rest, err, test.wantErr)
+				}
+				if test.wantErr {
+					return
+				}
+
 				if !cmp.Equal(got, test.want) {
 					t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v",
 						test.elms, test.rest, got, test.want)
@@ -152,6 +230,7 @@ func TestInvokes(t *testing.T) {
 			rest     *VetRestriction
 			want     float64
 			restWant *VetRestriction
+			wantErr  bool
 		}{
 			{
 				name:     "SingleElem",
@@ -168,6 +247,13 @@ func TestInvokes(t *testing.T) {
 				want:     3,
 				restWant: &VetRestriction{ID: "KvSdf", RestSize: true, Key: 1, Val: 2},
 			},
+			{
+				name:    "Error",
+				sdf:     restSizeErrSdf,
+				elms:    &FullValue{Elm: 1},
+				rest:    &VetRestriction{ID: "Sdf"},
+				wantErr: true,
+			},
 		}
 		for _, test := range tests {
 			test := test
@@ -178,7 +264,15 @@ func TestInvokes(t *testing.T) {
 					t.Fatalf("newRestrictionSizeInvoker failed: %v", err)
 				}
 				rest := *test.rest // Create a copy because our test SDF edits the restriction.
-				got := invoker.Invoke(test.elms, &rest)
+
+				got, err := invoker.Invoke(context.Background(), test.elms, &rest)
+				if (err != nil) != test.wantErr {
+					t.Fatalf("Invoke(%v, %v) error = %v, wantErr %v", test.elms, test.rest, err, test.wantErr)
+				}
+				if test.wantErr {
+					return
+				}
+
 				if !cmp.Equal(got, test.want) {
 					t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v",
 						test.elms, test.rest, got, test.want)
@@ -199,10 +293,11 @@ func TestInvokes(t *testing.T) {
 
 	t.Run("CreateTracker Invoker (ctInvoker)", func(t *testing.T) {
 		tests := []struct {
-			name string
-			sdf  *graph.SplittableDoFn
-			rest *VetRestriction
-			want *VetRTracker
+			name    string
+			sdf     *graph.SplittableDoFn
+			rest    *VetRestriction
+			want    *VetRTracker
+			wantErr bool
 		}{
 			{
 				name: "SingleElem",
@@ -215,6 +310,12 @@ func TestInvokes(t *testing.T) {
 				rest: &VetRestriction{ID: "KvSdf"},
 				want: &VetRTracker{&VetRestriction{ID: "KvSdf", CreateTracker: true}},
 			},
+			{
+				name:    "Error",
+				sdf:     trackerErrSdf,
+				rest:    &VetRestriction{ID: "Sdf"},
+				wantErr: true,
+			},
 		}
 		for _, test := range tests {
 			test := test
@@ -224,7 +325,15 @@ func TestInvokes(t *testing.T) {
 				if err != nil {
 					t.Fatalf("newCreateTrackerInvoker failed: %v", err)
 				}
-				got := invoker.Invoke(test.rest)
+
+				got, err := invoker.Invoke(context.Background(), test.rest)
+				if (err != nil) != test.wantErr {
+					t.Fatalf("Invoke(%v) error = %v, wantErr %v", test.rest, err, test.wantErr)
+				}
+				if test.wantErr {
+					return
+				}
+
 				if !cmp.Equal(got, test.want) {
 					t.Errorf("Invoke(%v) has incorrect output: got: %v, want: %v",
 						test.rest, got, test.want)
@@ -310,11 +419,12 @@ func TestInvokes(t *testing.T) {
 
 	t.Run("TruncateRestriction Invoker (trInvoker)", func(t *testing.T) {
 		tests := []struct {
-			name string
-			sdf  *graph.SplittableDoFn
-			elms *FullValue
-			rest *VetRestriction
-			want any
+			name    string
+			sdf     *graph.SplittableDoFn
+			elms    *FullValue
+			rest    *VetRestriction
+			want    any
+			wantErr bool
 		}{
 			{
 				name: "SingleElem",
@@ -329,6 +439,13 @@ func TestInvokes(t *testing.T) {
 				rest: &VetRestriction{ID: "KvSdf"},
 				want: &VetRestriction{ID: "KvSdf", CreateTracker: true, TruncateRest: true, RestSize: true, Key: 1, Val: 2},
 			},
+			{
+				name:    "Error",
+				sdf:     truncateRestErrSdf,
+				elms:    &FullValue{Elm: 1},
+				rest:    &VetRestriction{ID: "Sdf"},
+				wantErr: true,
+			},
 		}
 		for _, test := range tests {
 			test := test
@@ -336,24 +453,38 @@ func TestInvokes(t *testing.T) {
 			ctFn := test.sdf.CreateTrackerFn()
 			rsFn := test.sdf.RestrictionSizeFn()
 			t.Run(test.name, func(t *testing.T) {
+				ctx := context.Background()
 				rest := test.rest // Create a copy because our test SDF edits the restriction.
 				ctInvoker, err := newCreateTrackerInvoker(ctFn)
 				if err != nil {
 					t.Fatalf("newCreateTrackerInvoker failed: %v", err)
 				}
-				rt := ctInvoker.Invoke(rest)
+				rt, err := ctInvoker.Invoke(ctx, rest)
+				if err != nil {
+					t.Fatalf("ctInvoker.Invoke(%v) failed: %v", rest, err)
+				}
 
 				trInvoker, err := newTruncateRestrictionInvoker(fn)
 				if err != nil {
 					t.Fatalf("newTruncateRestrictionInvoker failed: %v", err)
 				}
-				trRest := trInvoker.Invoke(rt, test.elms)
+
+				trRest, err := trInvoker.Invoke(ctx, rt, test.elms)
+				if (err != nil) != test.wantErr {
+					t.Fatalf("trInvoker.Invoke(%v, %v) = %v, wantErr %v", rt, test.elms, err, test.wantErr)
+				}
+				if test.wantErr {
+					return
+				}
 
 				rsInvoker, err := newRestrictionSizeInvoker(rsFn)
 				if err != nil {
 					t.Fatalf("newRestrictionSizeInvoker failed: %v", err)
 				}
-				_ = rsInvoker.Invoke(test.elms, trRest)
+				if _, err := rsInvoker.Invoke(ctx, test.elms, trRest); err != nil {
+					t.Fatalf("rsInvoker.Invoke(%v, %v) failed: %v", test.elms, trRest, err)
+				}
+
 				if !cmp.Equal(trRest, test.want) {
 					t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v",
 						test.elms, test.rest, trRest, test.want)
@@ -411,24 +542,35 @@ func TestInvokes(t *testing.T) {
 			ctFn := test.sdf.CreateTrackerFn()
 			rsFn := test.sdf.RestrictionSizeFn()
 			t.Run(test.name, func(t *testing.T) {
+				ctx := context.Background()
 				rest := test.rest // Create a copy because our test SDF edits the restriction.
 				ctInvoker, err := newCreateTrackerInvoker(ctFn)
 				if err != nil {
 					t.Fatalf("newCreateTrackerInvoker failed: %v", err)
 				}
-				rt := ctInvoker.Invoke(rest)
+				rt, err := ctInvoker.Invoke(ctx, rest)
+				if err != nil {
+					t.Fatalf("ctInvoker.Invoke(%v) failed: %v", rest, err)
+				}
 
 				trInvoker, err := newDefaultTruncateRestrictionInvoker()
 				if err != nil {
 					t.Fatalf("newTruncateRestrictionInvoker failed: %v", err)
 				}
-				trRest := trInvoker.Invoke(rt, test.elms)
+				trRest, err := trInvoker.Invoke(ctx, rt, test.elms)
+				if err != nil {
+					t.Fatalf("trInvoker.Invoke(%v, %v) failed: %v", rt, test.elms, err)
+				}
+
 				if trRest != nil {
 					rsInvoker, err := newRestrictionSizeInvoker(rsFn)
 					if err != nil {
 						t.Fatalf("newRestrictionSizeInvoker failed: %v", err)
 					}
-					_ = rsInvoker.Invoke(test.elms, trRest)
+					if _, err := rsInvoker.Invoke(ctx, test.elms, trRest); err != nil {
+						t.Fatalf("rsInvoker.Invoke(%v, %v) failed: %v", test.elms, trRest, err)
+					}
+
 					if !cmp.Equal(trRest, test.want) {
 						t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v",
 							test.elms, test.rest, trRest, test.want)
@@ -732,3 +874,55 @@ func (fn *VetEmptyInitialSplitSdf) ProcessElement(rt *VetRTracker, i int, emit f
 	emit(rest)
 	return sdf.ResumeProcessingIn(1 * time.Second)
 }
+
+// errSdf is the error handed back by each failing Vet*ErrSdf test fixture.
+var errSdf = errors.New("SDF error")
+
+// VetCreateInitialRestrictionErrSdf embeds VetSdf and overrides only
+// CreateInitialRestriction, which always fails with errSdf.
+type VetCreateInitialRestrictionErrSdf struct {
+	VetSdf
+}
+
+// CreateInitialRestriction ignores its element and reports errSdf.
+func (fn *VetCreateInitialRestrictionErrSdf) CreateInitialRestriction(_ int) (*VetRestriction, error) {
+	return nil, errSdf
+}
+
+// VetSplitRestrictionErrSdf embeds VetSdf and overrides only
+// SplitRestriction, which always fails with errSdf.
+type VetSplitRestrictionErrSdf struct {
+	VetSdf
+}
+
+// SplitRestriction ignores its element and restriction and reports errSdf.
+func (fn *VetSplitRestrictionErrSdf) SplitRestriction(_ int, _ *VetRestriction) ([]*VetRestriction, error) {
+	return nil, errSdf
+}
+
+// VetRestrictionSizeErrSdf embeds VetSdf and overrides only
+// RestrictionSize, which always fails with errSdf.
+type VetRestrictionSizeErrSdf struct {
+	VetSdf
+}
+
+// RestrictionSize reports errSdf together with a placeholder size of -1.
+func (fn *VetRestrictionSizeErrSdf) RestrictionSize(_ int, _ *VetRestriction) (float64, error) {
+	return -1, errSdf
+}
+
+// VetCreateTrackerErrSdf embeds VetSdf and overrides only CreateTracker,
+// which always fails with errSdf.
+type VetCreateTrackerErrSdf struct {
+	VetSdf
+}
+
+// CreateTracker ignores its restriction and reports errSdf.
+func (fn *VetCreateTrackerErrSdf) CreateTracker(_ *VetRestriction) (*VetRTracker, error) {
+	return nil, errSdf
+}
+
+// VetTruncateRestrictionErrSdf embeds VetSdf and overrides only
+// TruncateRestriction, which always fails with errSdf.
+type VetTruncateRestrictionErrSdf struct {
+	VetSdf
+}
+
+// TruncateRestriction ignores its tracker and element and reports errSdf.
+func (fn *VetTruncateRestrictionErrSdf) TruncateRestriction(_ *VetRTracker, _ int) (*VetRestriction, error) {
+	return nil, errSdf
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
index 414f28553a8..a0380796e86 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
@@ -1114,10 +1114,11 @@ func TestAsSplittableUnit(t *testing.T) {
 
 				// Call from SplittableUnit and check results.
 				su := SplittableUnit(node)
-				if err := node.Up(context.Background()); err != nil {
+				ctx := context.Background()
+				if err := node.Up(ctx); err != nil {
 					t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err)
 				}
-				gotPrimaries, gotResiduals, err := su.Split(test.frac)
+				gotPrimaries, gotResiduals, err := su.Split(ctx, test.frac)
 				if err != nil {
 					t.Fatalf("SplittableUnit.Split(%v) failed: %v", test.frac, err)
 				}
@@ -1184,10 +1185,11 @@ func TestAsSplittableUnit(t *testing.T) {
 
 				// Call from SplittableUnit and check results.
 				su := SplittableUnit(node)
-				if err := node.Up(context.Background()); err != nil {
+				ctx := context.Background()
+				if err := node.Up(ctx); err != nil {
 					t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err)
 				}
-				_, _, err := su.Split(0.5)
+				_, _, err := su.Split(ctx, 0.5)
 				if err == nil {
 					t.Errorf("SplittableUnit.Split(%v) was expected to fail.", test.in)
 				}
@@ -1251,10 +1253,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				node.currW = 0
 				// Call from SplittableUnit and check results.
 				su := SplittableUnit(node)
-				if err := node.Up(context.Background()); err != nil {
+				ctx := context.Background()
+				if err := node.Up(ctx); err != nil {
 					t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err)
 				}
-				gotResiduals, err := su.Checkpoint()
+				gotResiduals, err := su.Checkpoint(ctx)
 
 				if err != nil {
 					t.Fatalf("SplittableUnit.Checkpoint() returned error, got %v", err)
@@ -1401,7 +1404,7 @@ func TestMultiWindowProcessing(t *testing.T) {
 	// Split should hit window boundary between 2 and 3. We don't need to check
 	// the split result here, just the effects it has on currW and numW.
 	frac := 0.5
-	if _, _, err := su.Split(frac); err != nil {
+	if _, _, err := su.Split(context.Background(), frac); err != nil {
 		t.Errorf("Split(%v) failed with error: %v", frac, err)
 	}
 	if got, want := node.currW, blockW; got != want {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/config/config.go b/sdks/go/pkg/beam/runners/prism/internal/config/config.go
index fc2b68d092f..9c3bdd012bc 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/config/config.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/config/config.go
@@ -244,4 +244,4 @@ func (r *HandlerRegistry) GetVariant(name string) *Variant {
 		return nil
 	}
 	return &Variant{parent: r, name: name, handlers: vs.Handlers}
-}
\ No newline at end of file
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go b/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
index 3c7cae97397..7b553f6ad65 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
@@ -33,4 +33,4 @@ func Test_toUrn(t *testing.T) {
 	if got := quickUrn(pipepb.StandardPTransforms_PAR_DO); got != want {
 		t.Errorf("quickUrn(\"pipepb.StandardPTransforms_PAR_DO\") = %v, want %v", got, want)
 	}
-}
\ No newline at end of file
+}