You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lo...@apache.org on 2022/05/04 02:18:31 UTC

[beam] branch master updated: [BEAM-11106] Support drain in Go SDK (#17432)

This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a6ea1485b96 [BEAM-11106] Support drain in Go SDK (#17432)
a6ea1485b96 is described below

commit a6ea1485b96e4bc3cf056e7e5db984fbb61da9cc
Author: Ritesh Ghorse <ri...@gmail.com>
AuthorDate: Tue May 3 22:18:21 2022 -0400

    [BEAM-11106] Support drain in Go SDK (#17432)
---
 CHANGES.md                                         |   1 +
 sdks/go/pkg/beam/core/graph/fn.go                  |  58 +++++++-
 sdks/go/pkg/beam/core/graph/fn_test.go             |  47 ++++++
 sdks/go/pkg/beam/core/runtime/exec/sdf.go          | 115 +++++++++++++++
 sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go |  83 +++++++++++
 .../beam/core/runtime/exec/sdf_invokers_test.go    | 160 ++++++++++++++++++++-
 sdks/go/pkg/beam/core/runtime/exec/sdf_test.go     | 104 ++++++++++++++
 sdks/go/pkg/beam/core/runtime/exec/translate.go    |   9 +-
 sdks/go/pkg/beam/core/runtime/genx/genx.go         |   5 +
 .../pkg/beam/core/runtime/graphx/translate_test.go |   1 +
 sdks/go/pkg/beam/core/sdf/lock.go                  |  12 ++
 sdks/go/pkg/beam/core/sdf/sdf.go                   |  12 ++
 .../beam/io/rtrackers/offsetrange/offsetrange.go   |   4 +
 13 files changed, 606 insertions(+), 5 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index cd8791567cd..a8683df52a1 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -66,6 +66,7 @@
 ## New Features / Improvements
 
 * 'Manage Clusters' JupyterLab extension added for users to configure usage of Dataproc clusters managed by Interactive Beam (Python) ([BEAM-14130](https://issues.apache.org/jira/browse/BEAM-14130)).
+* Pipeline drain support added for Go SDK ([BEAM-11106](https://issues.apache.org/jira/browse/BEAM-11106)). **Note: this feature is not yet fully validated and should be treated as experimental in this release.**
 * Go SDK users may now write self-checkpointing Splittable DoFns to read from streaming sources. **Note: this feature is not yet fully validated and should be treated as experimental in this release.** ([BEAM-11104](https://issues.apache.org/jira/browse/BEAM-11104))
 
 ## Breaking Changes
diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go
index 5dba9689a29..931458d922d 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -167,6 +167,7 @@ const (
 	splitRestrictionName         = "SplitRestriction"
 	restrictionSizeName          = "RestrictionSize"
 	createTrackerName            = "CreateTracker"
+	truncateRestrictionName      = "TruncateRestriction"
 
 	createWatermarkEstimatorName       = "CreateWatermarkEstimator"
 	initialWatermarkEstimatorStateName = "InitialWatermarkEstimatorState"
@@ -192,6 +193,7 @@ var doFnNames = []string{
 	restrictionSizeName,
 	createTrackerName,
 	createWatermarkEstimatorName,
+	truncateRestrictionName,
 	initialWatermarkEstimatorStateName,
 	watermarkEstimatorStateName,
 }
@@ -203,6 +205,12 @@ var requiredSdfNames = []string{
 	createTrackerName,
 }
 
+var optionalSdfNames = []string{
+	truncateRestrictionName,
+}
+
+var sdfNames = append(append([]string{}, requiredSdfNames...), optionalSdfNames...)
+
 var watermarkEstimationNames = []string{
 	createWatermarkEstimatorName,
 	initialWatermarkEstimatorStateName,
@@ -314,6 +322,17 @@ func (f *SplittableDoFn) RestrictionT() reflect.Type {
 	return f.CreateInitialRestrictionFn().Ret[0].T
 }
 
+// HasTruncateRestriction returns whether the DoFn implements a custom truncate restriction function.
+func (f *SplittableDoFn) HasTruncateRestriction() bool {
+	_, ok := f.methods[truncateRestrictionName]
+	return ok
+}
+
+// TruncateRestrictionFn returns the "TruncateRestriction" function, or nil if the DoFn does not define one.
+func (f *SplittableDoFn) TruncateRestrictionFn() *funcx.Fn {
+	return f.methods[truncateRestrictionName]
+}
+
 // IsWatermarkEstimating returns whether the DoFn implements a custom watermark estimator.
 func (f *SplittableDoFn) IsWatermarkEstimating() bool {
 	_, ok := f.methods[createWatermarkEstimatorName]
@@ -843,11 +862,18 @@ func validateSdfSigNumbers(fn *Fn, num int) error {
 		splitRestrictionName:         num + 1,
 		restrictionSizeName:          num + 1,
 		createTrackerName:            1,
+		truncateRestrictionName:      num + 1,
+	}
+	optionalSdfs := map[string]bool{
+		truncateRestrictionName: true,
 	}
 	returnNum := 1 // TODO(BEAM-3301): Enable optional error params in SDF methods.
 
-	for _, name := range requiredSdfNames {
-		method := fn.methods[name]
+	for _, name := range sdfNames {
+		method, ok := fn.methods[name]
+		if !ok && optionalSdfs[name] {
+			continue
+		}
 		if len(method.Param) != paramNums[name] {
 			err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v",
 				name, len(method.Param), paramNums[name])
@@ -950,6 +976,34 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 		}
 	}
 
+	rTrackerImplT := fn.methods[createTrackerName].Ret[0].T
+
+	for _, name := range optionalSdfNames {
+		method, ok := fn.methods[name]
+		if !ok {
+			continue
+		}
+		switch name {
+		case truncateRestrictionName:
+			if method.Param[0].T != rTrackerImplT {
+				err := errors.Errorf("mismatched restriction tracker type in method %v, param %v. got: %v, want: %v",
+					truncateRestrictionName, 0, method.Param[0].T, rTrackerImplT)
+				return errors.SetTopLevelMsgf(err, "Mismatched restriction tracker type in method %v, "+
+					"parameter at index %v. Got: %v, Want: %v (from method %v). "+
+					"Ensure that restriction tracker is the first parameter.",
+					truncateRestrictionName, 0, method.Param[0].T, rTrackerImplT, createTrackerName)
+			}
+			if method.Ret[0].T != restrictionT {
+				err := errors.Errorf("invalid output type in method %v, return %v. got: %v, want: %v",
+					truncateRestrictionName, 0, method.Ret[0].T, restrictionT)
+				return errors.SetTopLevelMsgf(err, "Invalid output type in method %v, "+
+					"return value at index %v. Got: %v, Want: %v (from method %v). "+
+					"Ensure that all restrictions in an SDF are the same type.",
+					truncateRestrictionName, 0, method.Ret[0].T, restrictionT, createInitialRestrictionName)
+			}
+		}
+	}
+
 	return nil
 }
 
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go
index efc25bea623..5f0e274b084 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -183,11 +183,13 @@ func TestNewDoFnSdf(t *testing.T) {
 			{dfn: &BadSdfParamsSplitRest{}},
 			{dfn: &BadSdfParamsRestSize{}},
 			{dfn: &BadSdfParamsCreateTracker{}},
+			{dfn: &BadSdfParamsTruncateRestriction{}},
 			// Validate return numbers.
 			{dfn: &BadSdfReturnsCreateRest{}},
 			{dfn: &BadSdfReturnsSplitRest{}},
 			{dfn: &BadSdfReturnsRestSize{}},
 			{dfn: &BadSdfReturnsCreateTracker{}},
+			{dfn: &BadSdfReturnsTruncateRestriction{}},
 			// Validate element types consistent with ProcessElement.
 			{dfn: &BadSdfElementTCreateRest{}},
 			{dfn: &BadSdfElementTSplitRest{}},
@@ -197,11 +199,13 @@ func TestNewDoFnSdf(t *testing.T) {
 			{dfn: &BadSdfRestTSplitRestReturn{}},
 			{dfn: &BadSdfRestTRestSize{}},
 			{dfn: &BadSdfRestTCreateTracker{}},
+			{dfn: &BadSdfRestTTruncateRestriction{}},
 			// Validate other types
 			{dfn: &BadSdfRestSizeReturn{}},
 			{dfn: &BadSdfCreateTrackerReturn{}},
 			{dfn: &BadSdfMismatchedRTracker{}},
 			{dfn: &BadSdfMissingRTracker{}},
+			{dfn: &BadSdfMismatchRTrackerTruncateRestriction{}},
 		}
 		for _, test := range tests {
 			t.Run(reflect.TypeOf(test.dfn).String(), func(t *testing.T) {
@@ -709,6 +713,9 @@ func (rt *RTrackerT) IsDone() bool {
 func (rt *RTrackerT) GetRestriction() interface{} {
 	return nil
 }
+func (rt *RTrackerT) IsBounded() bool {
+	return true
+}
 
 type RTracker2T struct{}
 
@@ -755,6 +762,10 @@ func (fn *GoodSdf) ProcessElement(*RTrackerT, int) int {
 	return 0
 }
 
+func (fn *GoodSdf) TruncateRestriction(*RTrackerT, int) RestT {
+	return RestT{}
+}
+
 type GoodSdfKv struct {
 	*GoodDoFnKv
 }
@@ -779,6 +790,10 @@ func (fn *GoodSdfKv) ProcessElement(*RTrackerT, int, int) int {
 	return 0
 }
 
+func (fn *GoodSdfKv) TruncateRestriction(*RTrackerT, int, int) RestT {
+	return RestT{}
+}
+
 type WatermarkEstimatorT struct{}
 
 func (e *WatermarkEstimatorT) CurrentWatermark() time.Time {
@@ -914,6 +929,14 @@ func (fn *BadSdfParamsCreateTracker) CreateTracker(int, RestT) *RTrackerT {
 	return &RTrackerT{}
 }
 
+type BadSdfParamsTruncateRestriction struct {
+	*GoodSdf
+}
+
+func (fn *BadSdfParamsTruncateRestriction) TruncateRestriction(*RTrackerT, int, int) RestT {
+	return RestT{}
+}
+
 // Examples with invalid numbers of return values.
 
 type BadSdfReturnsCreateRest struct {
@@ -948,6 +971,14 @@ func (fn *BadSdfReturnsCreateTracker) CreateTracker(RestT) (*RTrackerT, int) {
 	return &RTrackerT{}, 0
 }
 
+type BadSdfReturnsTruncateRestriction struct {
+	*GoodSdf
+}
+
+func (fn *BadSdfReturnsTruncateRestriction) TruncateRestriction(*RTrackerT, int) (RestT, int) {
+	return RestT{}, 0
+}
+
 // Examples with element types inconsistent with ProcessElement.
 
 type BadSdfElementTCreateRest struct {
@@ -1010,6 +1041,14 @@ func (fn *BadSdfRestTCreateTracker) CreateTracker(BadRestT) *RTrackerT {
 	return &RTrackerT{}
 }
 
+type BadSdfRestTTruncateRestriction struct {
+	*GoodSdf
+}
+
+func (fn *BadSdfRestTTruncateRestriction) TruncateRestriction(*RTrackerT, int) BadRestT {
+	return BadRestT{}
+}
+
 type BadWatermarkEstimatingNonSdf struct {
 	*GoodDoFn
 }
@@ -1058,6 +1097,14 @@ func (fn *BadSdfMismatchedRTracker) ProcessElement(*OtherRTrackerT, int) int {
 	return 0
 }
 
+type BadSdfMismatchRTrackerTruncateRestriction struct {
+	*GoodSdf
+}
+
+func (fn *BadSdfMismatchRTrackerTruncateRestriction) TruncateRestriction(*OtherRTrackerT, int) RestT {
+	return RestT{}
+}
+
 type BadWatermarkEstimatingCreateWatermarkEstimatorReturnType struct {
 	*GoodWatermarkEstimating
 }
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf.go b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
index ec457213e86..f5dd6a7431b 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
@@ -232,6 +232,121 @@ func (n *SplitAndSizeRestrictions) String() string {
 	return fmt.Sprintf("SDF.SplitAndSizeRestrictions[%v] UID:%v Out:%v", path.Base(n.Fn.Name()), n.UID, IDs(n.Out))
 }
 
+// TruncateSizedRestriction is an executor for the expanded SDF step of the
+// same name. This step is added to the expanded SDF when the runner signals to drain
+// the pipeline. This step is followed by ProcessSizedElementsAndRestrictions.
+type TruncateSizedRestriction struct {
+	UID         UnitID
+	Fn          *graph.DoFn
+	Out         Node
+	truncateInv *trInvoker
+	sizeInv     *rsInvoker
+	ctInv       *ctInvoker
+}
+
+// ID returns the UnitID for this unit.
+func (n *TruncateSizedRestriction) ID() UnitID {
+	return n.UID
+}
+
+// Up performs one-time setup for this executor.
+func (n *TruncateSizedRestriction) Up(ctx context.Context) error {
+	fn := (*graph.SplittableDoFn)(n.Fn).CreateTrackerFn()
+	var err error
+	if n.ctInv, err = newCreateTrackerInvoker(fn); err != nil {
+		return errors.WithContextf(err, "%v", n)
+	}
+
+	fn = (*graph.SplittableDoFn)(n.Fn).TruncateRestrictionFn()
+	if fn != nil {
+		if n.truncateInv, err = newTruncateRestrictionInvoker(fn); err != nil {
+			return err
+		}
+	} else {
+		if n.truncateInv, err = newDefaultTruncateRestrictionInvoker(); err != nil {
+			return err
+		}
+	}
+	fn = (*graph.SplittableDoFn)(n.Fn).RestrictionSizeFn()
+	if n.sizeInv, err = newRestrictionSizeInvoker(fn); err != nil {
+		return err
+	}
+	return nil
+}
+
+// StartBundle does no work of its own; it forwards the bundle start to the output node.
+func (n *TruncateSizedRestriction) StartBundle(ctx context.Context, id string, data DataContext) error {
+	return n.Out.StartBundle(ctx, id, data)
+}
+
+// ProcessElement gets input elm as:
+// Input Diagram:
+//   *FullValue {
+//     Elm: *FullValue {
+//       Elm:  *FullValue (original input)
+//       Elm2: *FullValue {
+// 	       Elm: Restriction
+// 	       Elm2: Watermark estimator state
+//       }
+//     }
+//     Elm2: float64 (size)
+//     Windows
+//     Timestamps
+//    }
+//
+// Output Diagram:
+//   *FullValue {
+//     Elm: *FullValue {
+//       Elm:  *FullValue (original input)
+//       Elm2: *FullValue {
+// 	       Elm: Restriction
+// 	       Elm2: Watermark estimator state
+//       }
+//     }
+//     Elm2: float64 (size)
+//     Windows
+//     Timestamps
+//    }
+func (n *TruncateSizedRestriction) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error {
+	mainElm := elm.Elm.(*FullValue).Elm.(*FullValue)
+	rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
+	rt := n.ctInv.Invoke(rest)
+	newRest := n.truncateInv.Invoke(rt, mainElm)
+	if newRest == nil {
+		// Drop restrictions discarded by truncation. NOTE(review): a typed nil such as (*T)(nil) returned by a user TruncateRestriction is not == nil here and would still be emitted.
+		return nil
+	}
+	size := n.sizeInv.Invoke(mainElm, newRest)
+	output := &FullValue{}
+	output.Timestamp = elm.Timestamp
+	output.Windows = elm.Windows
+	output.Elm = &FullValue{Elm: mainElm, Elm2: &FullValue{Elm: newRest, Elm2: elm.Elm.(*FullValue).Elm2.(*FullValue).Elm2}}
+	output.Elm2 = size
+
+	if err := n.Out.ProcessElement(ctx, output, values...); err != nil {
+		return err
+	}
+	return nil
+}
+
+// FinishBundle resets the invokers and then finishes the bundle on the output node.
+func (n *TruncateSizedRestriction) FinishBundle(ctx context.Context) error {
+	n.truncateInv.Reset()
+	n.sizeInv.Reset()
+	n.ctInv.Reset()
+	return n.Out.FinishBundle(ctx)
+}
+
+// Down currently does nothing.
+func (n *TruncateSizedRestriction) Down(_ context.Context) error {
+	return nil
+}
+
+// String outputs a human-readable description of this transform.
+func (n *TruncateSizedRestriction) String() string {
+	return fmt.Sprintf("SDF.TruncateSizedRestriction[%v] UID:%v Out:%v", path.Base(n.Fn.Name()), n.UID, IDs(n.Out))
+}
+
 // ProcessSizedElementsAndRestrictions is an executor for the expanded SDF step
 // of the same name. It is the final step of the expanded SDF. It sets up and
 // invokes the user's SDF methods, similar to exec.ParDo but with slight
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
index f90cdd0abf5..079af1443cd 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
@@ -301,6 +301,89 @@ func (n *ctInvoker) Reset() {
 	}
 }
 
+// trInvoker is an invoker for TruncateRestriction.
+type trInvoker struct {
+	fn   *funcx.Fn
+	args []interface{}
+	call func(rest interface{}, elms *FullValue) (pair interface{})
+}
+
+func defaultTruncateRestriction(restTracker interface{}) (newRest interface{}) {
+	if tracker, ok := restTracker.(sdf.BoundableRTracker); ok && !tracker.IsBounded() {
+		return nil
+	}
+	return restTracker.(sdf.RTracker).GetRestriction()
+}
+
+func newTruncateRestrictionInvoker(fn *funcx.Fn) (*trInvoker, error) {
+	n := &trInvoker{
+		fn:   fn,
+		args: make([]interface{}, len(fn.Param)),
+	}
+	if err := n.initCallFn(); err != nil {
+		return nil, errors.WithContext(err, "sdf TruncateRestriction invoker")
+	}
+	return n, nil
+}
+
+func newDefaultTruncateRestrictionInvoker() (*trInvoker, error) {
+	n := &trInvoker{}
+	n.call = func(rest interface{}, elms *FullValue) interface{} {
+		return defaultTruncateRestriction(rest)
+	}
+	return n, nil
+}
+
+func (n *trInvoker) initCallFn() error {
+	// Expects a signature of the form (the "rest" argument of n.call is the tracker):
+	// (restrictionTracker, key?, value) restriction
+	// TODO(BEAM-9643): Link to full documentation.
+	switch fnT := n.fn.Fn.(type) {
+	case reflectx.Func2x1:
+		n.call = func(rest interface{}, elms *FullValue) interface{} {
+			return fnT.Call2x1(rest, elms.Elm)
+		}
+	case reflectx.Func3x1:
+		n.call = func(rest interface{}, elms *FullValue) interface{} {
+			return fnT.Call3x1(rest, elms.Elm, elms.Elm2)
+		}
+	default:
+		switch len(n.fn.Param) {
+		case 2:
+			n.call = func(rest interface{}, elms *FullValue) interface{} {
+				n.args[0] = rest
+				n.args[1] = elms.Elm
+				return n.fn.Fn.Call(n.args)[0]
+			}
+		case 3:
+			n.call = func(rest interface{}, elms *FullValue) interface{} {
+				n.args[0] = rest
+				n.args[1] = elms.Elm
+				n.args[2] = elms.Elm2
+				return n.fn.Fn.Call(n.args)[0]
+			}
+		default:
+			return errors.Errorf("TruncateRestriction fn %v has unexpected number of parameters: %v",
+				n.fn.Fn.Name(), len(n.fn.Param))
+		}
+	}
+	return nil
+}
+
+// Invoke calls TruncateRestriction given a FullValue containing an element and
+// the associated restriction tracker, and returns a truncated restriction.
+func (n *trInvoker) Invoke(rt interface{}, elms *FullValue) (rest interface{}) {
+	return n.call(rt, elms)
+}
+
+// Reset zeroes argument entries in the cached slice to allow values to be
+// garbage collected after the bundle ends.
+func (n *trInvoker) Reset() {
+	for i := range n.args {
+		n.args[i] = nil
+	}
+}
+
 // cweInvoker is an invoker for CreateWatermarkEstimator.
 type cweInvoker struct {
 	fn   *funcx.Fn
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
index bf959d8ee01..de016a327a8 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
@@ -306,6 +306,142 @@ func TestInvokes(t *testing.T) {
 			t.Errorf("Invoke() has incorrect output: got: %v, want: %v", got, want)
 		}
 	})
+
+	t.Run("TruncateRestriction Invoker (trInvoker)", func(t *testing.T) {
+		tests := []struct {
+			name string
+			sdf  *graph.SplittableDoFn
+			elms *FullValue
+			rest *VetRestriction
+			want interface{}
+		}{
+			{
+				name: "SingleElem",
+				sdf:  sdf,
+				elms: &FullValue{Elm: 1},
+				rest: &VetRestriction{ID: "Sdf"},
+				want: &VetRestriction{ID: "Sdf", CreateTracker: true, TruncateRest: true, RestSize: true, Val: 1},
+			}, {
+				name: "KvElem",
+				sdf:  kvsdf,
+				elms: &FullValue{Elm: 1, Elm2: 2},
+				rest: &VetRestriction{ID: "KvSdf"},
+				want: &VetRestriction{ID: "KvSdf", CreateTracker: true, TruncateRest: true, RestSize: true, Key: 1, Val: 2},
+			},
+		}
+		for _, test := range tests {
+			test := test
+			fn := test.sdf.TruncateRestrictionFn()
+			ctFn := test.sdf.CreateTrackerFn()
+			rsFn := test.sdf.RestrictionSizeFn()
+			t.Run(test.name, func(t *testing.T) {
+				rest := test.rest // Create a copy because our test SDF edits the restriction.
+				ctInvoker, err := newCreateTrackerInvoker(ctFn)
+				if err != nil {
+					t.Fatalf("newCreateTrackerInvoker failed: %v", err)
+				}
+				rt := ctInvoker.Invoke(rest)
+
+				trInvoker, err := newTruncateRestrictionInvoker(fn)
+				if err != nil {
+					t.Fatalf("newTruncateRestrictionInvoker failed: %v", err)
+				}
+				trRest := trInvoker.Invoke(rt, test.elms)
+
+				rsInvoker, err := newRestrictionSizeInvoker(rsFn)
+				if err != nil {
+					t.Fatalf("newRestrictionSizeInvoker failed: %v", err)
+				}
+				_ = rsInvoker.Invoke(test.elms, trRest)
+				if !cmp.Equal(trRest, test.want) {
+					t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v",
+						test.elms, test.rest, trRest, test.want)
+				}
+				trInvoker.Reset()
+				for i, arg := range trInvoker.args {
+					if arg != nil {
+						t.Errorf("Reset() failed to empty all args. args[%v] = %v", i, arg)
+					}
+				}
+			})
+		}
+	})
+
+	t.Run("Default TruncateRestriction Invoker", func(t *testing.T) {
+		tests := []struct {
+			name string
+			sdf  *graph.SplittableDoFn
+			elms *FullValue
+			rest *VetRestriction
+			want interface{}
+		}{
+			{
+				name: "SingleElem",
+				sdf:  sdf,
+				elms: &FullValue{Elm: 1},
+				rest: &VetRestriction{ID: "Sdf", Bounded: true},
+				want: &VetRestriction{ID: "Sdf", Bounded: true, CreateTracker: true, RestSize: true, Val: 1},
+			},
+			{
+				name: "SingleElem",
+				sdf:  sdf,
+				elms: &FullValue{Elm: 1},
+				rest: &VetRestriction{ID: "Sdf", Bounded: false},
+				want: &VetRestriction{ID: "Sdf", Bounded: false, CreateTracker: true, RestSize: false, Val: 1},
+			},
+			{
+				name: "KvElem",
+				sdf:  kvsdf,
+				elms: &FullValue{Elm: 1, Elm2: 2},
+				rest: &VetRestriction{ID: "KvSdf", Bounded: true},
+				want: &VetRestriction{ID: "KvSdf", Bounded: true, CreateTracker: true, RestSize: true, Key: 1, Val: 2},
+			},
+			{
+				name: "KvElem",
+				sdf:  kvsdf,
+				elms: &FullValue{Elm: 1, Elm2: 2},
+				rest: &VetRestriction{ID: "KvSdf", Bounded: false},
+				want: &VetRestriction{ID: "KvSdf", Bounded: false, CreateTracker: true, RestSize: false, Key: 1, Val: 2},
+			},
+		}
+
+		for _, test := range tests {
+			test := test
+			ctFn := test.sdf.CreateTrackerFn()
+			rsFn := test.sdf.RestrictionSizeFn()
+			t.Run(test.name, func(t *testing.T) {
+				rest := test.rest // Create a copy because our test SDF edits the restriction.
+				ctInvoker, err := newCreateTrackerInvoker(ctFn)
+				if err != nil {
+					t.Fatalf("newCreateTrackerInvoker failed: %v", err)
+				}
+				rt := ctInvoker.Invoke(rest)
+
+				trInvoker, err := newDefaultTruncateRestrictionInvoker()
+				if err != nil {
+					t.Fatalf("newTruncateRestrictionInvoker failed: %v", err)
+				}
+				trRest := trInvoker.Invoke(rt, test.elms)
+				if trRest != nil {
+					rsInvoker, err := newRestrictionSizeInvoker(rsFn)
+					if err != nil {
+						t.Fatalf("newRestrictionSizeInvoker failed: %v", err)
+					}
+					_ = rsInvoker.Invoke(test.elms, trRest)
+					if !cmp.Equal(trRest, test.want) {
+						t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v",
+							test.elms, test.rest, trRest, test.want)
+					}
+					trInvoker.Reset()
+					for i, arg := range trInvoker.args {
+						if arg != nil {
+							t.Errorf("Reset() failed to empty all args. args[%v] = %v", i, arg)
+						}
+					}
+				}
+			})
+		}
+	})
 }
 
 // VetRestriction is a restriction used for validating that SDF methods get
@@ -322,9 +458,12 @@ type VetRestriction struct {
 	// confirm that the restriction saw the expected element.
 	Key, Val interface{}
 
+	// Bounded just tells if the restriction is bounded or not
+	Bounded bool
+
 	// These booleans should be flipped to true by the corresponding SDF methods
 	// to prove that the methods got called on the restriction.
-	CreateRest, SplitRest, RestSize, CreateTracker, ProcessElm bool
+	CreateRest, SplitRest, RestSize, CreateTracker, ProcessElm, TruncateRest bool
 }
 
 func (r VetRestriction) copy() VetRestriction {
@@ -345,6 +484,7 @@ func (rt *VetRTracker) GetRestriction() interface{}     { return nil }
 func (rt *VetRTracker) TrySplit(_ float64) (interface{}, interface{}, error) {
 	return nil, nil, nil
 }
+func (rt *VetRTracker) IsBounded() bool { return rt.Rest.Bounded }
 
 type VetWatermarkEstimator struct {
 	State int
@@ -391,6 +531,12 @@ func (fn *VetSdf) RestrictionSize(i int, rest *VetRestriction) float64 {
 	return (float64)(i)
 }
 
+// TruncateRestriction marks the tracked restriction as truncated (TruncateRest) and returns that same restriction.
+func (fn *VetSdf) TruncateRestriction(rest *VetRTracker, i int) *VetRestriction {
+	rest.Rest.TruncateRest = true
+	return rest.Rest
+}
+
 // CreateTracker creates an RTracker containing the given restriction and flips
 // the appropriate flags on the restriction to track that this was called.
 func (fn *VetSdf) CreateTracker(rest *VetRestriction) *VetRTracker {
@@ -505,6 +651,12 @@ func (fn *VetKvSdf) RestrictionSize(i, j int, rest *VetRestriction) float64 {
 	return (float64)(i + j)
 }
 
+// TruncateRestriction marks the tracked restriction as truncated (TruncateRest) and returns that same restriction.
+func (fn *VetKvSdf) TruncateRestriction(rest *VetRTracker, i, j int) *VetRestriction {
+	rest.Rest.TruncateRest = true
+	return rest.Rest
+}
+
 // CreateTracker creates an RTracker containing the given restriction and flips
 // the appropriate flags on the restriction to track that this was called.
 func (fn *VetKvSdf) CreateTracker(rest *VetRestriction) *VetRTracker {
@@ -552,6 +704,12 @@ func (fn *VetEmptyInitialSplitSdf) RestrictionSize(i int, rest *VetRestriction)
 	return (float64)(i)
 }
 
+// TruncateRestriction marks the tracked restriction as truncated (TruncateRest) and returns that same restriction.
+func (fn *VetEmptyInitialSplitSdf) TruncateRestriction(rest *VetRTracker, i int) *VetRestriction {
+	rest.Rest.TruncateRest = true
+	return rest.Rest
+}
+
 // CreateTracker creates an RTracker containing the given restriction and flips
 // the appropriate flags on the restriction to track that this was called.
 func (fn *VetEmptyInitialSplitSdf) CreateTracker(rest *VetRestriction) *VetRTracker {
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
index 1f9c56bdbac..8955d970707 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
@@ -560,6 +560,110 @@ func TestSdfNodes(t *testing.T) {
 			})
 		}
 	})
+
+	// Validate TruncateSizedRestriction matches its contract and properly
+	// invokes SDF methods TruncateRestriction and RestrictionSize.
+	t.Run("TruncateSizedRestriction", func(t *testing.T) {
+		tests := []struct {
+			name string
+			fn   *graph.DoFn
+			in   FullValue
+			want []FullValue
+		}{
+			{
+				name: "SingleElem",
+				fn:   dfn,
+				in: FullValue{
+					Elm: &FullValue{
+						Elm: &FullValue{
+							Elm:  1,
+							Elm2: nil,
+						},
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: nil,
+						},
+					},
+					Elm2:      1.0,
+					Timestamp: testTimestamp,
+					Windows:   testWindows,
+				},
+				want: []FullValue{
+					{
+						Elm: &FullValue{
+							Elm: &FullValue{
+								Elm:  1,
+								Elm2: nil,
+							},
+							Elm2: &FullValue{
+								Elm:  &VetRestriction{ID: "Sdf", CreateTracker: true, TruncateRest: true, RestSize: true, Val: 1},
+								Elm2: nil,
+							},
+						},
+						Elm2:      1.0,
+						Timestamp: testTimestamp,
+						Windows:   testWindows,
+					},
+				},
+			},
+			{
+				name: "KvElem",
+				fn:   kvdfn,
+				in: FullValue{
+					Elm: &FullValue{
+						Elm: &FullValue{
+							Elm:       1,
+							Elm2:      2,
+							Timestamp: testTimestamp,
+							Windows:   testWindows,
+						},
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "KvSdf"},
+							Elm2: nil,
+						},
+					},
+					Elm2:      3.0,
+					Timestamp: testTimestamp,
+					Windows:   testWindows,
+				},
+				want: []FullValue{
+					{
+						Elm: &FullValue{
+							Elm: &FullValue{
+								Elm:       1,
+								Elm2:      2,
+								Timestamp: testTimestamp,
+								Windows:   testWindows,
+							},
+							Elm2: &FullValue{
+								Elm:  &VetRestriction{ID: "KvSdf", CreateTracker: true, TruncateRest: true, RestSize: true, Key: 1, Val: 2},
+								Elm2: nil,
+							},
+						},
+						Elm2:      3.0,
+						Timestamp: testTimestamp,
+						Windows:   testWindows,
+					},
+				},
+			},
+		}
+		for _, test := range tests {
+			test := test
+			t.Run(test.name, func(t *testing.T) {
+				capt := &CaptureNode{UID: 2}
+				node := &TruncateSizedRestriction{UID: 1, Fn: test.fn, Out: capt}
+				root := &FixedRoot{UID: 0, Elements: []MainInput{{Key: test.in}}, Out: node}
+				units := []Unit{root, node, capt}
+				constructAndExecutePlan(t, units)
+
+				got := capt.Elements
+				if !cmp.Equal(got, test.want) {
+					t.Errorf("TruncateSizedRestriction(%v) has incorrect output: got: %v, want: %v",
+						test.in, got, test.want)
+				}
+			})
+		}
+	})
 }
 
 // TestAsSplittableUnit tests ProcessSizedElementsAndRestrictions' implementation
diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go
index 5eaac7c5644..c5b1cd0954a 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go
@@ -46,6 +46,7 @@ const (
 	urnPairWithRestriction                 = "beam:transform:sdf_pair_with_restriction:v1"
 	urnSplitAndSizeRestrictions            = "beam:transform:sdf_split_and_size_restrictions:v1"
 	urnProcessSizedElementsAndRestrictions = "beam:transform:sdf_process_sized_element_and_restrictions:v1"
+	urnTruncateSizedRestrictions           = "beam:transform:sdf_truncate_sized_restrictions:v1"
 )
 
 // UnmarshalPlan converts a model bundle descriptor into an execution Plan.
@@ -400,14 +401,16 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
 		urnPerKeyCombineConvert,
 		urnPairWithRestriction,
 		urnSplitAndSizeRestrictions,
-		urnProcessSizedElementsAndRestrictions:
+		urnProcessSizedElementsAndRestrictions,
+		urnTruncateSizedRestrictions:
 		var data string
 		var sides map[string]*pipepb.SideInput
 		switch urn {
 		case graphx.URNParDo,
 			urnPairWithRestriction,
 			urnSplitAndSizeRestrictions,
-			urnProcessSizedElementsAndRestrictions:
+			urnProcessSizedElementsAndRestrictions,
+			urnTruncateSizedRestrictions:
 			var pardo pipepb.ParDoPayload
 			if err := proto.Unmarshal(payload, &pardo); err != nil {
 				return nil, errors.Wrapf(err, "invalid ParDo payload for %v", transform)
@@ -453,6 +456,8 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
 					u = &PairWithRestriction{UID: b.idgen.New(), Fn: dofn, Out: out[0]}
 				case urnSplitAndSizeRestrictions:
 					u = &SplitAndSizeRestrictions{UID: b.idgen.New(), Fn: dofn, Out: out[0]}
+				case urnTruncateSizedRestrictions:
+					u = &TruncateSizedRestriction{UID: b.idgen.New(), Fn: dofn, Out: out[0]}
 				default:
 					n := &ParDo{UID: b.idgen.New(), Fn: dofn, Inbound: in, Out: out}
 					n.PID = transform.GetUniqueName()
diff --git a/sdks/go/pkg/beam/core/runtime/genx/genx.go b/sdks/go/pkg/beam/core/runtime/genx/genx.go
index b5bf766c268..19f9e7acbc2 100644
--- a/sdks/go/pkg/beam/core/runtime/genx/genx.go
+++ b/sdks/go/pkg/beam/core/runtime/genx/genx.go
@@ -112,6 +112,11 @@ func handleDoFn(fn *graph.DoFn, c cache) {
 	c.pullMethod(sdf.RestrictionSizeFn())
 	c.pullMethod(sdf.SplitRestrictionFn())
 	c.regType(sdf.RestrictionT())
+
+	if sdf.HasTruncateRestriction() {
+		c.pullMethod(sdf.TruncateRestrictionFn())
+	}
+
 	if !sdf.IsWatermarkEstimating() {
 		return
 	}
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate_test.go b/sdks/go/pkg/beam/core/runtime/graphx/translate_test.go
index de481856aba..dec418154c5 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/translate_test.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/translate_test.go
@@ -201,6 +201,7 @@ func (rt *testRT) GetError() error                 { return nil }
 func (rt *testRT) GetProgress() (float64, float64) { return 0, 0 }
 func (rt *testRT) IsDone() bool                    { return true }
 func (rt *testRT) GetRestriction() interface{}     { return nil }
+func (rt *testRT) IsBounded() bool                 { return true }
 func (rt *testRT) TrySplit(_ float64) (interface{}, interface{}, error) {
 	return nil, nil, nil
 }
diff --git a/sdks/go/pkg/beam/core/sdf/lock.go b/sdks/go/pkg/beam/core/sdf/lock.go
index d8f20534f0c..8a5d6448ee3 100644
--- a/sdks/go/pkg/beam/core/sdf/lock.go
+++ b/sdks/go/pkg/beam/core/sdf/lock.go
@@ -80,3 +80,15 @@ func (rt *LockRTracker) GetRestriction() interface{} {
 	defer rt.Mu.Unlock()
 	return rt.Rt.GetRestriction()
 }
+
+// IsBounded locks a mutex for thread safety, and then delegates to the
+// underlying tracker's IsBounded(). If BoundableRTracker is not implemented
+// then the RTracker is considered to be bounded by default.
+func (rt *LockRTracker) IsBounded() bool {
+	rt.Mu.Lock()
+	defer rt.Mu.Unlock()
+	if tracker, ok := rt.Rt.(BoundableRTracker); ok {
+		return tracker.IsBounded()
+	}
+	return true
+}
diff --git a/sdks/go/pkg/beam/core/sdf/sdf.go b/sdks/go/pkg/beam/core/sdf/sdf.go
index 8188ffb86c5..2876d5985a2 100644
--- a/sdks/go/pkg/beam/core/sdf/sdf.go
+++ b/sdks/go/pkg/beam/core/sdf/sdf.go
@@ -91,6 +91,18 @@ type RTracker interface {
 	GetRestriction() interface{}
 }
 
+// BoundableRTracker is an interface used to interact with restrictions that may be bounded or unbounded
+// while processing elements in splittable DoFns (specifically, in the ProcessElement method and TruncateRestriction method).
+// Each BoundableRTracker tracks the progress of a single restriction.
+//
+// All BoundableRTracker methods should be thread-safe for dynamic splits to function correctly.
+type BoundableRTracker interface {
+	RTracker
+	// IsBounded returns the boundedness of the current restriction. If the current restriction represents a
+	// finite amount of work, it should return true. Otherwise, it should return false.
+	IsBounded() bool
+}
+
 // WatermarkEstimator is an interface used to represent a user defined watermark estimator.
 // Watermark estimators allow users to advance the output watermark of the current sdf.
 type WatermarkEstimator interface {
diff --git a/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go b/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go
index a492bcddc74..743891e8055 100644
--- a/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go
+++ b/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go
@@ -217,3 +217,7 @@ func (tracker *Tracker) IsDone() bool {
 func (tracker *Tracker) GetRestriction() interface{} {
 	return tracker.rest
 }
+
+func (tracker *Tracker) IsBounded() bool {
+	return true
+}