You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lo...@apache.org on 2022/04/27 04:03:30 UTC

[beam] branch master updated: [BEAM-11105] Stateful watermark estimation (#17374)

This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e91f92e678e [BEAM-11105] Stateful watermark estimation (#17374)
e91f92e678e is described below

commit e91f92e678ec589888bb0d691039c57e3aa88c88
Author: Danny McCormick <da...@google.com>
AuthorDate: Wed Apr 27 00:03:22 2022 -0400

    [BEAM-11105] Stateful watermark estimation (#17374)
---
 sdks/go/pkg/beam/core/graph/fn.go                  | 217 +++++++++++---
 sdks/go/pkg/beam/core/graph/fn_test.go             | 172 ++++++++++-
 .../go/pkg/beam/core/runtime/exec/dynsplit_test.go |  14 +-
 sdks/go/pkg/beam/core/runtime/exec/sdf.go          | 126 +++++---
 sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go | 173 ++++++++++-
 .../beam/core/runtime/exec/sdf_invokers_test.go    | 137 ++++++++-
 sdks/go/pkg/beam/core/runtime/exec/sdf_test.go     | 320 +++++++++++++++++----
 sdks/go/pkg/beam/core/runtime/genx/genx.go         |   6 +
 sdks/go/pkg/beam/core/runtime/genx/genx_test.go    |  16 +-
 sdks/go/pkg/beam/pardo.go                          |  11 +-
 10 files changed, 1018 insertions(+), 174 deletions(-)

diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go
index 69460ee9f4a..775a3dfe9f3 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -168,7 +168,9 @@ const (
 	restrictionSizeName          = "RestrictionSize"
 	createTrackerName            = "CreateTracker"
 
-	createWatermarkEstimatorName = "CreateWatermarkEstimator"
+	createWatermarkEstimatorName       = "CreateWatermarkEstimator"
+	initialWatermarkEstimatorStateName = "InitialWatermarkEstimatorState"
+	watermarkEstimatorStateName        = "WatermarkEstimatorState"
 
 	createAccumulatorName = "CreateAccumulator"
 	addInputName          = "AddInput"
@@ -190,6 +192,8 @@ var doFnNames = []string{
 	restrictionSizeName,
 	createTrackerName,
 	createWatermarkEstimatorName,
+	initialWatermarkEstimatorStateName,
+	watermarkEstimatorStateName,
 }
 
 var requiredSdfNames = []string{
@@ -201,6 +205,8 @@ var requiredSdfNames = []string{
 
 var watermarkEstimationNames = []string{
 	createWatermarkEstimatorName,
+	initialWatermarkEstimatorStateName,
+	watermarkEstimatorStateName,
 }
 
 var combineFnNames = []string{
@@ -314,7 +320,7 @@ func (f *SplittableDoFn) IsWatermarkEstimating() bool {
 	return ok
 }
 
-// createWatermarkEstimatorFn returns the "createWatermarkEstimator" function, if present
+// CreateWatermarkEstimatorFn returns the "CreateWatermarkEstimator" function, if present.
 func (f *SplittableDoFn) CreateWatermarkEstimatorFn() *funcx.Fn {
 	return f.methods[createWatermarkEstimatorName]
 }
@@ -324,6 +330,27 @@ func (f *SplittableDoFn) WatermarkEstimatorT() reflect.Type {
 	return f.CreateWatermarkEstimatorFn().Ret[0].T
 }
 
+// IsStatefulWatermarkEstimating reports whether the SDF defines a WatermarkEstimatorState method.
+func (f *SplittableDoFn) IsStatefulWatermarkEstimating() bool {
+	_, ok := f.methods[watermarkEstimatorStateName]
+	return ok
+}
+
+// InitialWatermarkEstimatorStateFn returns the "InitialWatermarkEstimatorState" function, or nil if absent.
+func (f *SplittableDoFn) InitialWatermarkEstimatorStateFn() *funcx.Fn {
+	return f.methods[initialWatermarkEstimatorStateName]
+}
+
+// WatermarkEstimatorStateFn returns the "WatermarkEstimatorState" function, or nil if absent.
+func (f *SplittableDoFn) WatermarkEstimatorStateFn() *funcx.Fn {
+	return f.methods[watermarkEstimatorStateName]
+}
+
+// WatermarkEstimatorStateT returns WatermarkEstimatorState's return type; only call when IsStatefulWatermarkEstimating is true.
+func (f *SplittableDoFn) WatermarkEstimatorStateT() reflect.Type {
+	return f.WatermarkEstimatorStateFn().Ret[0].T
+}
+
 // TODO(herohde) 5/19/2017: we can sometimes detect whether the main input must be
 // a KV or not based on the other signatures (unless we're more loose about which
 // sideinputs are present). Bind should respect that.
@@ -519,7 +546,7 @@ func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) {
 	}
 
 	if isWatermarkEstimating {
-		err := validateWatermarkSig(fn)
+		err := validateWatermarkSig(fn, int(numMainIn))
 		if err != nil {
 			return nil, addContext(err, fn)
 		}
@@ -852,11 +879,11 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 		method := fn.methods[name]
 		switch name {
 		case createInitialRestrictionName:
-			if err := validateSdfElementT(fn, createInitialRestrictionName, method, num); err != nil {
+			if err := validateSdfElementT(fn, createInitialRestrictionName, method, num, 0); err != nil {
 				return err
 			}
 		case splitRestrictionName:
-			if err := validateSdfElementT(fn, splitRestrictionName, method, num); err != nil {
+			if err := validateSdfElementT(fn, splitRestrictionName, method, num, 0); err != nil {
 				return err
 			}
 			if method.Param[num].T != restrictionT {
@@ -877,7 +904,7 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 					splitRestrictionName, 0, method.Ret[0].T, reflect.SliceOf(restrictionT), createInitialRestrictionName, splitRestrictionName)
 			}
 		case restrictionSizeName:
-			if err := validateSdfElementT(fn, restrictionSizeName, method, num); err != nil {
+			if err := validateSdfElementT(fn, restrictionSizeName, method, num, 0); err != nil {
 				return err
 			}
 			if method.Param[num].T != restrictionT {
@@ -928,15 +955,15 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 
 // validateSdfElementT validates that element types in an SDF method are
 // consistent with the ProcessElement method. This method assumes that the
-// first 'num' parameters to the SDF method are the elements.
-func validateSdfElementT(fn *Fn, name string, method *funcx.Fn, num int) error {
+// 'num' consecutive parameters beginning at index startIndex are the elements.
+func validateSdfElementT(fn *Fn, name string, method *funcx.Fn, num int, startIndex int) error {
 	// ProcessElement is the most canonical source of the element type. We can
 	// processFn is valid by this point and skip unnecessary validation.
 	processFn := fn.methods[processElementName]
 	pos, _, _ := processFn.Inputs()
 
 	for i := 0; i < num; i++ {
-		if method.Param[i].T != processFn.Param[pos+i].T {
+		if method.Param[i+startIndex].T != processFn.Param[pos+i].T {
 			err := errors.Errorf("mismatched element type in method %v, param %v. got: %v, want: %v",
 				name, i, method.Param[i].T, processFn.Param[pos+i].T)
 			return errors.SetTopLevelMsgf(err, "Mismatched element type in method %v, "+
@@ -961,45 +988,163 @@ func validateIsWatermarkEstimating(fn *Fn, isSdf bool) (bool, error) {
 }
 
 // validateWatermarkSig validates that all watermark related functions are valid
-func validateWatermarkSig(fn *Fn) error {
-	paramRange := map[string][]int{
-		createWatermarkEstimatorName: []int{0, 0},
-	}
+func validateWatermarkSig(fn *Fn, numMainIn int) error {
 	returnNum := 1 // TODO(BEAM-3301): Enable optional error params in SDF methods.
 
 	watermarkEstimatorT := reflect.TypeOf((*sdf.WatermarkEstimator)(nil)).Elem()
+	method := fn.methods[createWatermarkEstimatorName]
+
+	if len(method.Param) > 1 {
+		err := errors.Errorf("unexpected number of params in method %v. got: %v, want number in range: 0 to 1",
+			createWatermarkEstimatorName, len(method.Param))
+		return errors.SetTopLevelMsgf(err, "unexpected number of parameters in method %v. "+
+			"got: %v, want number in range: 0 to 1. Check that the signature conforms to the expected signature for %v.",
+			createWatermarkEstimatorName, len(method.Param), createWatermarkEstimatorName)
+	} else if len(method.Param) == 1 {
+		err := validateStatefulWatermarkSig(fn, numMainIn)
+		if err != nil {
+			return err
+		}
+	} else {
+		if _, ok := fn.methods[initialWatermarkEstimatorStateName]; ok {
+			err := errors.Errorf("stateful watermark estimation method %v is present, "+
+				"but CreateWatermarkEstimator doesn't take in a state parameter.", initialWatermarkEstimatorStateName)
+			return err
+		}
+		if _, ok := fn.methods[watermarkEstimatorStateName]; ok {
+			err := errors.Errorf("stateful watermark estimation method %v is present, "+
+				"but CreateWatermarkEstimator doesn't take in a state parameter.", watermarkEstimatorStateName)
+			return err
+		}
+	}
+
+	if len(method.Ret) != returnNum {
+		err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v",
+			createWatermarkEstimatorName, len(method.Ret), returnNum)
+		return errors.SetTopLevelMsgf(err, "unexpected number of return values in method %v. "+
+			"got: %v, want: %v. Check that the signature conforms to the expected signature for %v.",
+			createWatermarkEstimatorName, len(method.Ret), returnNum, createWatermarkEstimatorName)
+	} else if !method.Ret[0].T.Implements(watermarkEstimatorT) {
+		err := errors.Errorf("invalid output type in method %v, return %v: %v does not implement sdf.WatermarkEstimator",
+			createWatermarkEstimatorName, 0, method.Ret[0].T)
+		return errors.SetTopLevelMsgf(err, "invalid output type in method %v, "+
+			"return value at index %v (type: %v). Output of method %v must implement sdf.WatermarkEstimator.",
+			createWatermarkEstimatorName, 0, method.Ret[0].T, createWatermarkEstimatorName)
+	}
 
+	return nil
+}
+
+func validateStatefulWatermarkSig(fn *Fn, numMainIn int) error {
+	// Store missing method names so we can output them to the user if validation fails.
+	var missing []string
 	for _, name := range watermarkEstimationNames {
-		if method, ok := fn.methods[name]; ok {
-			if len(method.Param) < paramRange[name][0] || len(method.Param) > paramRange[name][1] {
-				err := errors.Errorf("unexpected number of params in method %v. got: %v, want number in range: %v to %v",
-					name, len(method.Param), paramRange[name][0], paramRange[name][1])
-				return errors.SetTopLevelMsgf(err, "Unexpected number of parameters in method %v. "+
-					"Got: %v, Want number in range: %v to %v. Check that the signature conforms to the expected signature for %v, "+
+		_, ok := fn.methods[name]
+		if !ok {
+			missing = append(missing, name)
+		}
+	}
+	if len(missing) > 0 {
+		err := errors.Errorf("not all required stateful watermark estimation methods are present, "+
+			"but CreateWatermarkEstimator takes in a state parameter. Missing methods: %v", missing)
+		return err
+	}
+
+	restT := fn.methods[createInitialRestrictionName].Ret[0].T
+	watermarkStateT := fn.methods[createWatermarkEstimatorName].Param[0].T
+	watermarkEstimatorT := fn.methods[createWatermarkEstimatorName].Ret[0].T
+
+	// If number of main inputs is ambiguous, we check for consistency against
+	// CreateInitialRestriction.
+	if numMainIn == int(MainUnknown) {
+		initialRestFn := fn.methods[createInitialRestrictionName]
+		paramNum := len(initialRestFn.Param)
+		switch paramNum {
+		case int(MainSingle), int(MainKv):
+			numMainIn = paramNum
+		}
+	}
+
+	for _, name := range watermarkEstimationNames {
+		method := fn.methods[name]
+		switch name {
+		case initialWatermarkEstimatorStateName:
+			if len(method.Param) != numMainIn+2 {
+				err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v",
+					initialWatermarkEstimatorStateName, len(method.Param), numMainIn+2)
+				return errors.SetTopLevelMsgf(err, "unexpected number of parameters in method %v. "+
+					"got: %v, want: %v. Check that the signature conforms to the expected signature for %v, "+
 					"and that elements in SDF method parameters match elements in %v.",
-					name, len(method.Param), paramRange[name][0], paramRange[name][1], name, processElementName)
+					initialWatermarkEstimatorStateName, len(method.Param), numMainIn+2, initialWatermarkEstimatorStateName, processElementName)
+			}
+			if method.Param[0].T != typex.EventTimeType {
+				err := errors.Errorf("unexpected parameter type in method %v, param %v. got: %v, want: %v",
+					initialWatermarkEstimatorStateName, 0, method.Param[0].T, typex.EventTimeType)
+				return errors.SetTopLevelMsgf(err, "mismatched event time type in method %v, "+
+					"parameter at index %v. got: %v, want: %v.",
+					initialWatermarkEstimatorStateName, 0, method.Param[0].T, typex.EventTimeType)
 			}
-			if len(method.Ret) != returnNum {
-				err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v",
-					name, len(method.Ret), returnNum)
-				return errors.SetTopLevelMsgf(err, "Unexpected number of return values in method %v. "+
-					"Got: %v, Want: %v. Check that the signature conforms to the expected signature for %v.",
-					name, len(method.Ret), returnNum, name)
+			if method.Param[1].T != restT {
+				err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
+					initialWatermarkEstimatorStateName, 1, method.Param[1].T, restT)
+				return errors.SetTopLevelMsgf(err, "mismatched restriction type in method %v, "+
+					"parameter at index %v. got: %v, want: %v (from method %v). "+
+					"Ensure that all restrictions in an SDF are the same type.",
+					initialWatermarkEstimatorStateName, 1, method.Param[1].T, restT, createInitialRestrictionName)
+			}
+			if err := validateSdfElementT(fn, initialWatermarkEstimatorStateName, method, numMainIn, 2); err != nil {
+				return err
+			}
 
-			switch name {
-			case createWatermarkEstimatorName:
-				if !method.Ret[0].T.Implements(watermarkEstimatorT) {
-					err := errors.Errorf("invalid output type in method %v, return %v: %v does not implement sdf.WatermarkEstimator",
-						createWatermarkEstimatorName, 0, method.Ret[0].T)
-					return errors.SetTopLevelMsgf(err, "Invalid output type in method %v, "+
-						"return value at index %v (type: %v). Output of method %v must implement sdf.WatermarkEstimator.",
-						createWatermarkEstimatorName, 0, method.Ret[0].T, createWatermarkEstimatorName)
-				}
+			if len(method.Ret) != 1 {
+				err := errors.Errorf("unexpected number of elements returned in method %v. got: %v, want %v",
+					initialWatermarkEstimatorStateName, len(method.Ret), 1)
+				return errors.SetTopLevelMsgf(err, "unexpected number of elements returned in method %v. "+
+					"got: %v, want %v. Check that the signature conforms to the expected signature for %v.",
+					initialWatermarkEstimatorStateName, len(method.Ret), 1, initialWatermarkEstimatorStateName)
+			}
+			if method.Ret[0].T != watermarkStateT {
+				err := errors.Errorf("mismatched output type in method %v, return %v. got: %v, want: %v",
+					initialWatermarkEstimatorStateName, 0, method.Ret[0].T, watermarkStateT)
+				return errors.SetTopLevelMsgf(err, "mismatched output type in method %v, "+
+					"return value at index %v got: %v, want: %v (from method %v). "+
+					"Ensure that all watermark states in an SDF are the same type.",
+					initialWatermarkEstimatorStateName, 0, method.Ret[0].T, watermarkStateT, createWatermarkEstimatorName)
+			}
+		case watermarkEstimatorStateName:
+			if len(method.Param) != 1 {
+				err := errors.Errorf("unexpected number of params in method %v. got: %v, want %v",
+					watermarkEstimatorStateName, len(method.Param), 1)
+				return errors.SetTopLevelMsgf(err, "unexpected number of parameters in method %v. "+
+					"got: %v, want %v. Check that the signature conforms to the expected signature for %v, "+
+					"and that elements in SDF method parameters match elements in %v.",
+					watermarkEstimatorStateName, len(method.Param), 1, watermarkEstimatorStateName, processElementName)
+			}
+			if method.Param[0].T != watermarkEstimatorT {
+				err := errors.Errorf("mismatched watermark estimator type in method %v, param %v. got: %v, want: %v",
+					watermarkEstimatorStateName, 0, method.Param[0].T, watermarkEstimatorT)
+				return errors.SetTopLevelMsgf(err, "mismatched watermark estimator type in method %v, "+
+					"parameter at index %v got: %v, want: %v (from method %v). "+
+					"Ensure that all watermark estimators in an SDF are the same type.",
+					watermarkEstimatorStateName, 0, method.Param[0].T, watermarkEstimatorT, createWatermarkEstimatorName)
+			}
+			if len(method.Ret) != 1 {
+				err := errors.Errorf("unexpected number of elements returned in method %v. got: %v, want %v",
+					watermarkEstimatorStateName, len(method.Ret), 1)
+				return errors.SetTopLevelMsgf(err, "unexpected number of elements returned in method %v. "+
+					"got: %v, want %v. Check that the signature conforms to the expected signature for %v.",
+					watermarkEstimatorStateName, len(method.Ret), 1, watermarkEstimatorStateName)
+			}
+			if method.Ret[0].T != watermarkStateT {
+				err := errors.Errorf("mismatched output type in method %v, return %v. got: %v, want: %v",
+					watermarkEstimatorStateName, 0, method.Ret[0].T, watermarkStateT)
+				return errors.SetTopLevelMsgf(err, "mismatched output type in method %v, "+
+					"return value at index %v got: %v, want: %v (from method %v). "+
+					"Ensure that all watermark states in an SDF are the same type.",
+					watermarkEstimatorStateName, 0, method.Ret[0].T, watermarkStateT, createWatermarkEstimatorName)
 			}
 		}
 	}
-
 	return nil
 }
 
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go
index e38c9a34af6..c04d2d07529 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -235,6 +235,9 @@ func TestNewDoFnWatermarkEstimating(t *testing.T) {
 			main mainInputs
 		}{
 			{dfn: &GoodWatermarkEstimating{}, main: MainSingle},
+			{dfn: &GoodWatermarkEstimatingKv{}, main: MainKv},
+			{dfn: &GoodStatefulWatermarkEstimating{}, main: MainSingle},
+			{dfn: &GoodStatefulWatermarkEstimatingKv{}, main: MainKv},
 		}
 
 		for _, test := range tests {
@@ -255,6 +258,17 @@ func TestNewDoFnWatermarkEstimating(t *testing.T) {
 		}{
 			{dfn: &BadWatermarkEstimatingNonSdf{}},
 			{dfn: &BadWatermarkEstimatingCreateWatermarkEstimatorReturnType{}},
+			{dfn: &BadStatefulWatermarkEstimatingInconsistentState{}},
+			{dfn: &BadStatefulWatermarkEstimatingInconsistentEstimator{}},
+			{dfn: &BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateParams{}},
+			{dfn: &BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoParams{}},
+			{dfn: &BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateReturns{}},
+			{dfn: &BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoReturns{}},
+			{dfn: &BadStatefulWatermarkEstimatingWrongPositionalParameter0{}},
+			{dfn: &BadStatefulWatermarkEstimatingWrongPositionalParameter1{}},
+			{dfn: &BadStatefulWatermarkEstimatingWrongPositionalParameter2{}},
+			{dfn: &BadStatefulKvWatermarkEstimatingWrongPositionalParameter2{}},
+			{dfn: &BadStatefulWatermarkEstimatingWrongReturn{}},
 		}
 		for _, test := range tests {
 			t.Run(reflect.TypeOf(test.dfn).String(), func(t *testing.T) {
@@ -693,6 +707,27 @@ func (rt *RTrackerT) GetRestriction() interface{} {
 	return nil
 }
 
+type RTracker2T struct{}
+
+func (rt *RTracker2T) TryClaim(interface{}) bool {
+	return false
+}
+func (rt *RTracker2T) GetError() error {
+	return nil
+}
+func (rt *RTracker2T) TrySplit(fraction float64) (interface{}, interface{}, error) {
+	return nil, nil, nil
+}
+func (rt *RTracker2T) GetProgress() (float64, float64) {
+	return 0, 0
+}
+func (rt *RTracker2T) IsDone() bool {
+	return true
+}
+func (rt *RTracker2T) GetRestriction() interface{} {
+	return nil
+}
+
 type GoodSdf struct {
 	*GoodDoFn
 }
@@ -747,34 +782,64 @@ func (e WatermarkEstimatorT) CurrentWatermark() time.Time {
 	return time.Now()
 }
 
+type WatermarkEstimator2T struct{}
+
+func (e WatermarkEstimator2T) CurrentWatermark() time.Time {
+	return time.Now()
+}
+
+func (e WatermarkEstimator2T) CurrentWatermark2() time.Time {
+	return time.Now()
+}
+
 type GoodWatermarkEstimating struct {
-	*GoodDoFn
+	*GoodSdf
 }
 
-func (fn *GoodWatermarkEstimating) CreateInitialRestriction(int) RestT {
-	return RestT{}
+func (fn *GoodWatermarkEstimating) CreateWatermarkEstimator() WatermarkEstimatorT {
+	return WatermarkEstimatorT{}
 }
 
-func (fn *GoodWatermarkEstimating) SplitRestriction(int, RestT) []RestT {
-	return []RestT{}
+type GoodWatermarkEstimatingKv struct {
+	*GoodSdfKv
+}
+
+func (fn *GoodWatermarkEstimatingKv) CreateWatermarkEstimator() WatermarkEstimatorT {
+	return WatermarkEstimatorT{}
+}
+
+type GoodStatefulWatermarkEstimating struct {
+	*GoodSdf
 }
 
-func (fn *GoodWatermarkEstimating) RestrictionSize(int, RestT) float64 {
+func (fn *GoodStatefulWatermarkEstimating) InitialWatermarkEstimatorState(ts typex.EventTime, rt RestT, element int) int {
 	return 0
 }
 
-func (fn *GoodWatermarkEstimating) CreateTracker(RestT) *RTrackerT {
-	return &RTrackerT{}
+func (fn *GoodStatefulWatermarkEstimating) CreateWatermarkEstimator(state int) WatermarkEstimatorT {
+	return WatermarkEstimatorT{}
 }
 
-func (fn *GoodWatermarkEstimating) ProcessElement(*RTrackerT, int) int {
+func (fn *GoodStatefulWatermarkEstimating) WatermarkEstimatorState(estimator WatermarkEstimatorT) int {
 	return 0
 }
 
-func (fn *GoodWatermarkEstimating) CreateWatermarkEstimator() WatermarkEstimatorT {
+type GoodStatefulWatermarkEstimatingKv struct {
+	*GoodSdfKv
+}
+
+func (fn *GoodStatefulWatermarkEstimatingKv) InitialWatermarkEstimatorState(ts typex.EventTime, rt RestT, k int, v int) int {
+	return 0
+}
+
+func (fn *GoodStatefulWatermarkEstimatingKv) CreateWatermarkEstimator(state int) WatermarkEstimatorT {
 	return WatermarkEstimatorT{}
 }
 
+func (fn *GoodStatefulWatermarkEstimatingKv) WatermarkEstimatorState(estimator WatermarkEstimatorT) int {
+	return 0
+}
+
 // Examples of incorrect SDF signatures.
 // Examples with missing methods.
 
@@ -972,6 +1037,93 @@ func (fn *BadWatermarkEstimatingCreateWatermarkEstimatorReturnType) CreateWaterm
 	return 5
 }
 
+type BadStatefulWatermarkEstimatingInconsistentState struct {
+	*GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingInconsistentState) WatermarkEstimatorState(estimator WatermarkEstimatorT) string {
+	return ""
+}
+
+type BadStatefulWatermarkEstimatingInconsistentEstimator struct {
+	*GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingInconsistentEstimator) WatermarkEstimatorState(estimator WatermarkEstimator2T) int {
+	return 0
+}
+
+type BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateParams struct {
+	*GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateParams) WatermarkEstimatorState(estimator WatermarkEstimatorT, element int) int {
+	return 0
+}
+
+type BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoParams struct {
+	*GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoParams) WatermarkEstimatorState() int {
+	return 0
+}
+
+type BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateReturns struct {
+	*GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateReturns) WatermarkEstimatorState(estimator WatermarkEstimatorT) (int, error) {
+	return 0, nil
+}
+
+type BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoReturns struct {
+	*GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoReturns) WatermarkEstimatorState(estimator WatermarkEstimatorT) {
+}
+// BadStatefulWatermarkEstimatingWrongPositionalParameter0 has a non-EventTime first parameter; other params are valid.
+type BadStatefulWatermarkEstimatingWrongPositionalParameter0 struct {
+	*GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingWrongPositionalParameter0) InitialWatermarkEstimatorState(a int, rt RestT, element int) int {
+	return 0
+}
+// BadStatefulWatermarkEstimatingWrongPositionalParameter1 has a restriction parameter that is not RestT.
+type BadStatefulWatermarkEstimatingWrongPositionalParameter1 struct {
+	*GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingWrongPositionalParameter1) InitialWatermarkEstimatorState(ts typex.EventTime, rt *RTracker2T, element int) int {
+	return 0
+}
+// BadStatefulWatermarkEstimatingWrongPositionalParameter2 takes a string element where the good SDF takes an int.
+type BadStatefulWatermarkEstimatingWrongPositionalParameter2 struct {
+	*GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingWrongPositionalParameter2) InitialWatermarkEstimatorState(ts typex.EventTime, rt RestT, element string) int {
+	return 0
+}
+// BadStatefulKvWatermarkEstimatingWrongPositionalParameter2 takes one element param although the embedded SDF is KV.
+type BadStatefulKvWatermarkEstimatingWrongPositionalParameter2 struct {
+	*GoodStatefulWatermarkEstimatingKv
+}
+
+func (fn *BadStatefulKvWatermarkEstimatingWrongPositionalParameter2) InitialWatermarkEstimatorState(ts typex.EventTime, rt RestT, element int) int {
+	return 0
+}
+// BadStatefulWatermarkEstimatingWrongReturn returns string, but CreateWatermarkEstimator takes an int state.
+type BadStatefulWatermarkEstimatingWrongReturn struct {
+	*GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingWrongReturn) InitialWatermarkEstimatorState(ts typex.EventTime, rt RestT, element int) string {
+	return ""
+}
+
 // Examples of correct CombineFn signatures
 
 type MyAccum struct{}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go b/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go
index 64355a84c1f..bd2835fe2ff 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go
@@ -125,7 +125,7 @@ func TestDynamicSplit(t *testing.T) {
 			if err := procRes; err != nil {
 				t.Fatal(err)
 			}
-			pRest := p.Elm.(*FullValue).Elm2.(offsetrange.Restriction)
+			pRest := p.Elm.(*FullValue).Elm2.(*FullValue).Elm.(offsetrange.Restriction)
 			if got, want := len(out.Elements), int(pRest.End-pRest.Start); got != want {
 				t.Errorf("Unexpected number of elements: got: %v, want: %v", got, want)
 			}
@@ -226,8 +226,11 @@ func claimBlockingDriver(plan *Plan, dc DataContext, sdf *splitTestSdf) (splitRe
 func createElm() *FullValue {
 	return &FullValue{
 		Elm: &FullValue{
-			Elm:  20,
-			Elm2: offsetrange.Restriction{Start: 0, End: 20},
+			Elm: 20,
+			Elm2: &FullValue{
+				Elm:  offsetrange.Restriction{Start: 0, End: 20},
+				Elm2: false,
+			},
 		},
 		Elm2: float64(20),
 	}
@@ -244,7 +247,10 @@ func createSplitTestInCoder() *coder.Coder {
 		coder.NewKV([]*coder.Coder{
 			coder.NewKV([]*coder.Coder{
 				intCoder(reflectx.Int),
-				{Kind: coder.Custom, T: typex.New(restT), Custom: restCdr},
+				coder.NewKV([]*coder.Coder{
+					{Kind: coder.Custom, T: typex.New(restT), Custom: restCdr},
+					coder.NewBool(),
+				}),
 			}),
 			coder.NewDouble(),
 		}),
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf.go b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
index cd1234c29e3..ec457213e86 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
@@ -21,6 +21,7 @@ import (
 	"math"
 	"path"
 
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
@@ -37,7 +38,8 @@ type PairWithRestriction struct {
 	Fn  *graph.DoFn
 	Out Node
 
-	inv *cirInvoker
+	inv     *cirInvoker
+	iwesInv *iwesInvoker
 }
 
 // ID returns the UnitID for this unit.
@@ -52,6 +54,13 @@ func (n *PairWithRestriction) Up(_ context.Context) error {
 	if n.inv, err = newCreateInitialRestrictionInvoker(fn); err != nil {
 		return errors.WithContextf(err, "%v", n)
 	}
+	var giwesFn *funcx.Fn
+	if (*graph.SplittableDoFn)(n.Fn).IsStatefulWatermarkEstimating() {
+		giwesFn = (*graph.SplittableDoFn)(n.Fn).InitialWatermarkEstimatorStateFn()
+	}
+	if n.iwesInv, err = newInitialWatermarkEstimatorStateInvoker(giwesFn); err != nil {
+		return errors.WithContextf(err, "%v", n)
+	}
 	return nil
 }
 
@@ -73,13 +82,16 @@ func (n *PairWithRestriction) StartBundle(ctx context.Context, id string, data D
 //
 //   *FullValue {
 //     Elm: *FullValue (original input)
-//     Elm2: Restriction
+//     Elm2: *FullValue {
+//       Elm: Restriction
+//       Elm2: Watermark estimator state
+//     }
 //     Windows
 //     Timestamps
 //   }
 func (n *PairWithRestriction) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error {
 	rest := n.inv.Invoke(elm)
-	output := FullValue{Elm: elm, Elm2: rest, Timestamp: elm.Timestamp, Windows: elm.Windows}
+	output := FullValue{Elm: elm, Elm2: &FullValue{Elm: rest, Elm2: n.iwesInv.Invoke(rest, elm)}, Timestamp: elm.Timestamp, Windows: elm.Windows}
 
 	return n.Out.ProcessElement(ctx, &output, values...)
 }
@@ -87,6 +99,7 @@ func (n *PairWithRestriction) ProcessElement(ctx context.Context, elm *FullValue
 // FinishBundle resets the invokers.
 func (n *PairWithRestriction) FinishBundle(ctx context.Context) error {
 	n.inv.Reset()
+	n.iwesInv.Reset()
 	return n.Out.FinishBundle(ctx)
 }
 
@@ -147,13 +160,16 @@ func (n *SplitAndSizeRestrictions) StartBundle(ctx context.Context, id string, d
 //
 //   *FullValue {
 //     Elm: *FullValue (original input)
-//     Elm2: Restriction
+//     Elm2: *FullValue {
+//       Elm: Restriction
+//       Elm2: Watermark estimator state
+//     }
 //     Windows
 //     Timestamps
 //   }
 //
 // ProcessElement splits the given restriction into one or more restrictions and
-// then sizes each. The outputs are in the structure <<elem, restriction>, size>
+// then sizes each. The outputs are in the structure <<elem, <restriction, watermark estimator state>>, size>
 // where elem is the original main input to the unexpanded SDF. Windows and
 // Timestamps are copied to each split output.
 //
@@ -162,14 +178,18 @@ func (n *SplitAndSizeRestrictions) StartBundle(ctx context.Context, id string, d
 //   *FullValue {
 //     Elm: *FullValue {
 //       Elm:  *FullValue (original input)
-//       Elm2: Restriction
+//       Elm2: *FullValue {
+//         Elm: Restriction
+//         Elm2: Watermark estimator state
+//       }
 //     }
 //     Elm2: float64 (size)
 //     Windows
 //     Timestamps
 //   }
 func (n *SplitAndSizeRestrictions) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error {
-	rest := elm.Elm2
+	rest := elm.Elm2.(*FullValue).Elm
+	ws := elm.Elm2.(*FullValue).Elm2
 	mainElm := elm.Elm.(*FullValue)
 
 	splitRests := n.splitInv.Invoke(mainElm, rest)
@@ -184,7 +204,7 @@ func (n *SplitAndSizeRestrictions) ProcessElement(ctx context.Context, elm *Full
 
 		output.Timestamp = elm.Timestamp
 		output.Windows = elm.Windows
-		output.Elm = &FullValue{Elm: mainElm, Elm2: splitRest}
+		output.Elm = &FullValue{Elm: mainElm, Elm2: &FullValue{Elm: splitRest, Elm2: ws}}
 		output.Elm2 = size
 
 		if err := n.Out.ProcessElement(ctx, output, values...); err != nil {
@@ -223,6 +243,7 @@ type ProcessSizedElementsAndRestrictions struct {
 	ctInv   *ctInvoker
 	sizeInv *rsInvoker
 	cweInv  *cweInvoker
+	wesInv  *wesInvoker
 
 	// SU is a buffered channel for indicating when this unit is splittable.
 	// When this unit is processing an element, it sends a SplittableUnit
@@ -242,9 +263,10 @@ type ProcessSizedElementsAndRestrictions struct {
 	// from a DoFn for use in splitting the bundle if the process should be resumed.
 	continuation sdf.ProcessContinuation
 
-	elm   *FullValue   // Currently processing element.
-	rt    sdf.RTracker // Currently processing element's restriction tracker.
-	currW int          // Index of the current window in elm being processed.
+	elm     *FullValue   // Currently processing element.
+	rt      sdf.RTracker // Currently processing element's restriction tracker.
+	currW   int          // Index of the current window in elm being processed.
+	initWeS interface{}  // Initial state of the watermark estimator before processing elements.
 
 	// Number of windows being processed. This number can differ from the number
 	// of windows in an element, indicating to only process a subset of windows.
@@ -278,6 +300,13 @@ func (n *ProcessSizedElementsAndRestrictions) Up(ctx context.Context) error {
 			return errors.WithContextf(err, "%v", n)
 		}
 	}
+	var gwesFn *funcx.Fn
+	if (*graph.SplittableDoFn)(n.PDo.Fn).IsStatefulWatermarkEstimating() {
+		gwesFn = (*graph.SplittableDoFn)(n.PDo.Fn).WatermarkEstimatorStateFn()
+	}
+	if n.wesInv, err = newWatermarkEstimatorStateInvoker(gwesFn); err != nil {
+		return errors.WithContextf(err, "%v", n)
+	}
 	n.SU = make(chan SplittableUnit, 1)
 	return n.PDo.Up(ctx)
 }
@@ -288,7 +317,7 @@ func (n *ProcessSizedElementsAndRestrictions) StartBundle(ctx context.Context, i
 }
 
 // ProcessElement expects the same structure as the output of
-// SplitAndSizeRestrictions, approximately <<elem, restriction>, size>. The
+// SplitAndSizeRestrictions, approximately <<elem, <restriction, watermark estimator state>>, size>. The
 // only difference is that if the input was decoded in between the two steps,
 // then single-element inputs were lifted from the *FullValue they were
 // stored in.
@@ -298,7 +327,10 @@ func (n *ProcessSizedElementsAndRestrictions) StartBundle(ctx context.Context, i
 //   *FullValue {
 //     Elm: *FullValue {
 //       Elm:  *FullValue (KV input) or InputType (single-element input)
-//       Elm2: Restriction
+//       Elm2: *FullValue {
+//         Elm:  Restriction
+//         Elm2: Watermark estimator state
+//       }
 //     }
 //     Elm2: float64 (size)
 //     Windows
@@ -343,15 +375,16 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context,
 	}
 
 	if n.cweInv != nil {
-		n.PDo.we = n.cweInv.Invoke()
+		n.PDo.we = n.cweInv.Invoke(elm.Elm.(*FullValue).Elm2.(*FullValue).Elm2)
 	}
+	n.initWeS = n.wesInv.Invoke(n.PDo.we)
 
 	// Begin processing elements, exploding windows if necessary.
 	n.currW = 0
 	if !mustExplodeWindows(n.PDo.inv.fn, elm, len(n.PDo.Side) > 0) {
 		// If windows don't need to be exploded (i.e. aren't observed), treat
 		// all windows as one as an optimization.
-		rest := elm.Elm.(*FullValue).Elm2
+		rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
 		rt := n.ctInv.Invoke(rest)
 		mainIn.RTracker = rt
 
@@ -373,7 +406,7 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context,
 		n.numW = len(elm.Windows)
 
 		for i := 0; i < n.numW; i++ {
-			rest := elm.Elm.(*FullValue).Elm2
+			rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
 			rt := n.ctInv.Invoke(rest)
 			key := &mainIn.Key
 			w := elm.Windows[i]
@@ -402,6 +435,7 @@ func (n *ProcessSizedElementsAndRestrictions) FinishBundle(ctx context.Context)
 	if n.cweInv != nil {
 		n.cweInv.Reset()
 	}
+	n.wesInv.Reset()
 	return n.PDo.FinishBundle(ctx)
 }
 
@@ -457,6 +491,17 @@ type SplittableUnit interface {
 // each case occurs and the implementation details, see the documentation for
 // the singleWindowSplit and multiWindowSplit methods.
 func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []*FullValue, error) {
+	// Get the watermark state immediately so that we don't overestimate our current watermark.
+	var pWeState interface{}
+	var rWeState interface{}
+	rWeState = n.wesInv.Invoke(n.PDo.we)
+	pWeState = rWeState
+	// If we've processed elements, the initial watermark estimator state will be set.
+	// In that case the primary keeps that initial state so that we don't advance the output
+	// watermark past where the current elements are holding it.
+	if n.initWeS != nil {
+		pWeState = n.initWeS
+	}
 	addContext := func(err error) error {
 		return errors.WithContext(err, "Attempting split in ProcessSizedElementsAndRestrictions")
 	}
@@ -472,7 +517,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []
 	// Split behavior differs depending on whether this is a window-observing
 	// DoFn or not.
 	if len(n.elm.Windows) > 1 {
-		p, r, err := n.multiWindowSplit(f)
+		p, r, err := n.multiWindowSplit(f, pWeState, rWeState)
 		if err != nil {
 			return nil, nil, addContext(err)
 		}
@@ -480,7 +525,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []
 	}
 
 	// Not window-observing, or window-observing but only one window.
-	p, r, err := n.singleWindowSplit(f)
+	p, r, err := n.singleWindowSplit(f, pWeState, rWeState)
 	if err != nil {
 		return nil, nil, addContext(err)
 	}
@@ -492,7 +537,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []
 // behavior is identical). A single restriction split will occur and all windows
 // present in the unsplit element will be present in both the resulting primary
 // and residual.
-func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, pWeState, rWeState interface{}) ([]*FullValue, []*FullValue, error) {
 	if n.rt.IsDone() { // Not an error, but not splittable.
 		return []*FullValue{}, []*FullValue{}, nil
 	}
@@ -505,11 +550,11 @@ func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64) ([]*F
 		return []*FullValue{}, []*FullValue{}, nil
 	}
 
-	pfv, err := n.newSplitResult(p, n.elm.Windows)
+	pfv, err := n.newSplitResult(p, n.elm.Windows, pWeState)
 	if err != nil {
 		return nil, nil, err
 	}
-	rfv, err := n.newSplitResult(r, n.elm.Windows)
+	rfv, err := n.newSplitResult(r, n.elm.Windows, rWeState)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -540,7 +585,7 @@ func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64) ([]*F
 //
 // This method also updates the current number of windows (n.numW) so that
 // windows in the residual will no longer be processed.
-func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64, pWeState interface{}, rWeState interface{}) ([]*FullValue, []*FullValue, error) {
 	// Get the split point in window range, to see what window it falls in.
 	done, rem := n.rt.GetProgress()
 	cwp := done / (done + rem)                      // Progress in current window.
@@ -553,25 +598,25 @@ func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64) ([]*Fu
 		if n.rt.IsDone() {
 			// Current RTracker is done so we can't split within the window, so
 			// split at window boundary instead.
-			return n.windowBoundarySplit(n.currW + 1)
+			return n.windowBoundarySplit(n.currW+1, pWeState, rWeState)
 		}
 
 		// Get the fraction of remaining work in the current window to split at.
 		cwsp := wsp - float64(n.currW) // Split point in current window.
 		rf := (cwsp - cwp) / (1 - cwp) // Fraction of work in RTracker to split at.
 
-		return n.currentWindowSplit(rf)
+		return n.currentWindowSplit(rf, pWeState, rWeState)
 	} else {
 		// Split at nearest window boundary to split point.
 		wb := math.Round(wsp)
-		return n.windowBoundarySplit(int(wb))
+		return n.windowBoundarySplit(int(wb), pWeState, rWeState)
 	}
 }
 
 // currentWindowSplit performs an appropriate split at the given fraction of
 // remaining work in the current window. Also updates numW to stop after the
 // current window.
-func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeState interface{}, rWeState interface{}) ([]*FullValue, []*FullValue, error) {
 	p, r, err := n.rt.TrySplit(f)
 	if err != nil {
 		return nil, nil, err
@@ -579,33 +624,33 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64) ([]*
 	if r == nil {
 		// If r is nil then the split failed/returned an empty residual, but
 		// we can still split at a window boundary.
-		return n.windowBoundarySplit(n.currW + 1)
+		return n.windowBoundarySplit(n.currW+1, pWeState, rWeState)
 	}
 
 	// Split of currently processing restriction in a single window.
 	ps := make([]*FullValue, 1)
-	newP, err := n.newSplitResult(p, n.elm.Windows[n.currW:n.currW+1])
+	newP, err := n.newSplitResult(p, n.elm.Windows[n.currW:n.currW+1], pWeState)
 	if err != nil {
 		return nil, nil, err
 	}
 	ps[0] = newP
 	rs := make([]*FullValue, 1)
-	newR, err := n.newSplitResult(r, n.elm.Windows[n.currW:n.currW+1])
+	newR, err := n.newSplitResult(r, n.elm.Windows[n.currW:n.currW+1], rWeState)
 	if err != nil {
 		return nil, nil, err
 	}
 	rs[0] = newR
 	// Window boundary split surrounding the split restriction above.
-	full := n.elm.Elm.(*FullValue).Elm2
+	full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
 	if 0 < n.currW {
-		newP, err := n.newSplitResult(full, n.elm.Windows[0:n.currW])
+		newP, err := n.newSplitResult(full, n.elm.Windows[0:n.currW], pWeState)
 		if err != nil {
 			return nil, nil, err
 		}
 		ps = append(ps, newP)
 	}
 	if n.currW+1 < n.numW {
-		newR, err := n.newSplitResult(full, n.elm.Windows[n.currW+1:n.numW])
+		newR, err := n.newSplitResult(full, n.elm.Windows[n.currW+1:n.numW], rWeState)
 		if err != nil {
 			return nil, nil, err
 		}
@@ -618,17 +663,17 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64) ([]*
 // windowBoundarySplit performs an appropriate split at a window boundary. The
 // split point taken should be the index of the first window in the residual.
 // Also updates numW to stop at the split point.
-func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int, pWeState interface{}, rWeState interface{}) ([]*FullValue, []*FullValue, error) {
 	// If this is at the boundary of the last window, split is a no-op.
 	if splitPt == n.numW {
 		return []*FullValue{}, []*FullValue{}, nil
 	}
-	full := n.elm.Elm.(*FullValue).Elm2
-	pFv, err := n.newSplitResult(full, n.elm.Windows[0:splitPt])
+	full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
+	pFv, err := n.newSplitResult(full, n.elm.Windows[0:splitPt], pWeState)
 	if err != nil {
 		return nil, nil, err
 	}
-	rFv, err := n.newSplitResult(full, n.elm.Windows[splitPt:n.numW])
+	rFv, err := n.newSplitResult(full, n.elm.Windows[splitPt:n.numW], rWeState)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -640,7 +685,7 @@ func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int) (
 // element restriction pair based on the currently processing element, but with
 // a modified restriction and windows. Intended for creating primaries and
 // residuals to return as split results.
-func (n *ProcessSizedElementsAndRestrictions) newSplitResult(rest interface{}, w []typex.Window) (*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) newSplitResult(rest interface{}, w []typex.Window, weState interface{}) (*FullValue, error) {
 	var size float64
 	elm := n.elm.Elm.(*FullValue).Elm
 	if fv, ok := elm.(*FullValue); ok {
@@ -659,8 +704,11 @@ func (n *ProcessSizedElementsAndRestrictions) newSplitResult(rest interface{}, w
 	}
 	return &FullValue{
 		Elm: &FullValue{
-			Elm:  elm,
-			Elm2: rest,
+			Elm: elm,
+			Elm2: &FullValue{
+				Elm:  rest,
+				Elm2: weState,
+			},
 		},
 		Elm2:      size,
 		Timestamp: n.elm.Timestamp,
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
index 79eeda23afc..f90cdd0abf5 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
@@ -305,7 +305,7 @@ func (n *ctInvoker) Reset() {
 type cweInvoker struct {
 	fn   *funcx.Fn
 	args []interface{} // Cache to avoid allocating new slices per-element.
-	call func() sdf.WatermarkEstimator
+	call func(rest interface{}) sdf.WatermarkEstimator
 }
 
 func newCreateWatermarkEstimatorInvoker(fn *funcx.Fn) (*cweInvoker, error) {
@@ -321,27 +321,38 @@ func newCreateWatermarkEstimatorInvoker(fn *funcx.Fn) (*cweInvoker, error) {
 
 func (n *cweInvoker) initCallFn() error {
 	// Expects a signature of the form:
-	// () sdf.WatermarkEstimator
+	// (watermarkState?) sdf.WatermarkEstimator
 	switch fnT := n.fn.Fn.(type) {
 	case reflectx.Func0x1:
-		n.call = func() sdf.WatermarkEstimator {
+		n.call = func(rest interface{}) sdf.WatermarkEstimator {
 			return fnT.Call0x1().(sdf.WatermarkEstimator)
 		}
+	case reflectx.Func1x1:
+		n.call = func(rest interface{}) sdf.WatermarkEstimator {
+			return fnT.Call1x1(rest).(sdf.WatermarkEstimator)
+		}
 	default:
-		if len(n.fn.Param) != 0 {
+		switch len(n.fn.Param) {
+		case 0:
+			n.call = func(rest interface{}) sdf.WatermarkEstimator {
+				return n.fn.Fn.Call(n.args)[0].(sdf.WatermarkEstimator)
+			}
+		case 1:
+			n.call = func(rest interface{}) sdf.WatermarkEstimator {
+				n.args[0] = rest
+				return n.fn.Fn.Call(n.args)[0].(sdf.WatermarkEstimator)
+			}
+		default:
 			return errors.Errorf("CreateWatermarkEstimator fn %v has unexpected number of parameters: %v",
 				n.fn.Fn.Name(), len(n.fn.Param))
 		}
-		n.call = func() sdf.WatermarkEstimator {
-			return n.fn.Fn.Call(n.args)[0].(sdf.WatermarkEstimator)
-		}
 	}
 	return nil
 }
 
-// Invoke calls CreateWatermarkEstimator given a restriction and returns an sdf.WatermarkEstimator.
-func (n *cweInvoker) Invoke() sdf.WatermarkEstimator {
-	return n.call()
+// Invoke calls CreateWatermarkEstimator, passing rest as the watermark estimator state when the method takes one, and returns an sdf.WatermarkEstimator.
+func (n *cweInvoker) Invoke(rest interface{}) sdf.WatermarkEstimator {
+	return n.call(rest)
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -351,3 +362,147 @@ func (n *cweInvoker) Reset() {
 		n.args[i] = nil
 	}
 }
+
+// iwesInvoker is an invoker for InitialWatermarkEstimatorState.
+type iwesInvoker struct {
+	fn   *funcx.Fn
+	args []interface{} // Cache to avoid allocating new slices per-element.
+	call func(rest interface{}, elms *FullValue) interface{}
+}
+
+func newInitialWatermarkEstimatorStateInvoker(fn *funcx.Fn) (*iwesInvoker, error) {
+	args := []interface{}{}
+	if fn != nil {
+		args = make([]interface{}, len(fn.Param))
+	}
+	n := &iwesInvoker{
+		fn:   fn,
+		args: args,
+	}
+	if err := n.initCallFn(); err != nil {
+		return nil, errors.WithContext(err, "sdf InitialWatermarkEstimatorState invoker")
+	}
+	return n, nil
+}
+
+func (n *iwesInvoker) initCallFn() error {
+	// If no InitialWatermarkEstimatorState function is defined, we'll use a default implementation that just returns false as the state.
+	if n.fn == nil {
+		n.call = func(rest interface{}, elms *FullValue) interface{} {
+			return false
+		}
+		return nil
+	}
+	// Expects a signature of the form:
+	// (typex.EventTime, restriction, key?, value) interface{}
+	switch fnT := n.fn.Fn.(type) {
+	case reflectx.Func3x1:
+		n.call = func(rest interface{}, elms *FullValue) interface{} {
+			return fnT.Call3x1(elms.Timestamp, rest, elms.Elm)
+		}
+	case reflectx.Func4x1:
+		n.call = func(rest interface{}, elms *FullValue) interface{} {
+			return fnT.Call4x1(elms.Timestamp, rest, elms.Elm, elms.Elm2)
+		}
+	default:
+		switch len(n.fn.Param) {
+		case 3:
+			n.call = func(rest interface{}, elms *FullValue) interface{} {
+				n.args[0] = elms.Timestamp
+				n.args[1] = rest
+				n.args[2] = elms.Elm
+				return n.fn.Fn.Call(n.args)[0]
+			}
+		case 4:
+			n.call = func(rest interface{}, elms *FullValue) interface{} {
+				n.args[0] = elms.Timestamp
+				n.args[1] = rest
+				n.args[2] = elms.Elm
+				n.args[3] = elms.Elm2
+				return n.fn.Fn.Call(n.args)[0]
+			}
+		default:
+			return errors.Errorf("InitialWatermarkEstimatorState fn %v has unexpected number of parameters: %v",
+				n.fn.Fn.Name(), len(n.fn.Param))
+		}
+	}
+	return nil
+}
+
+// Invoke calls InitialWatermarkEstimatorState given a restriction and element, and returns the initial watermark estimator state.
+func (n *iwesInvoker) Invoke(rest interface{}, elms *FullValue) interface{} {
+	return n.call(rest, elms)
+}
+
+// Reset zeroes argument entries in the cached slice to allow values to be
+// garbage collected after the bundle ends.
+func (n *iwesInvoker) Reset() {
+	for i := range n.args {
+		n.args[i] = nil
+	}
+}
+
+// wesInvoker is an invoker for WatermarkEstimatorState.
+type wesInvoker struct {
+	fn   *funcx.Fn
+	args []interface{} // Cache to avoid allocating new slices per-element.
+	call func(we sdf.WatermarkEstimator) interface{}
+}
+
+func newWatermarkEstimatorStateInvoker(fn *funcx.Fn) (*wesInvoker, error) {
+	args := []interface{}{}
+	if fn != nil {
+		args = make([]interface{}, len(fn.Param))
+	}
+	n := &wesInvoker{
+		fn:   fn,
+		args: args,
+	}
+	if err := n.initCallFn(); err != nil {
+		return nil, errors.WithContext(err, "sdf WatermarkEstimatorState invoker")
+	}
+	return n, nil
+}
+
+func (n *wesInvoker) initCallFn() error {
+	// If no WatermarkEstimatorState function is defined, we'll use a default implementation that just returns false as the state.
+	if n.fn == nil {
+		n.call = func(we sdf.WatermarkEstimator) interface{} {
+			return false
+		}
+		return nil
+	}
+	// Expects a signature of the form:
+	// (sdf.WatermarkEstimator) state
+	switch fnT := n.fn.Fn.(type) {
+	case reflectx.Func1x1:
+		n.call = func(we sdf.WatermarkEstimator) interface{} {
+			return fnT.Call1x1(we)
+		}
+	default:
+		switch len(n.fn.Param) {
+		case 1:
+			n.call = func(we sdf.WatermarkEstimator) interface{} {
+				n.args[0] = we
+				return n.fn.Fn.Call(n.args)[0]
+			}
+		default:
+			return errors.Errorf("WatermarkEstimatorState fn %v has unexpected number of parameters: %v",
+				n.fn.Fn.Name(), len(n.fn.Param))
+		}
+	}
+	return nil
+}
+
+// Invoke calls WatermarkEstimatorState on the given watermark estimator and returns its current state.
+func (n *wesInvoker) Invoke(we sdf.WatermarkEstimator) interface{} {
+	return n.call(we)
+}
+
+// Reset zeroes argument entries in the cached slice to allow values to be
+// garbage collected after the bundle ends.
+func (n *wesInvoker) Reset() {
+	for i := range n.args {
+		n.args[i] = nil
+	}
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
index 10f5f899384..bf959d8ee01 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
@@ -20,6 +20,8 @@ import (
 	"time"
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
 	"github.com/google/go-cmp/cmp"
 )
 
@@ -39,6 +41,12 @@ func TestInvokes(t *testing.T) {
 	}
 	kvsdf := (*graph.SplittableDoFn)(dfn)
 
+	dfn, err = graph.NewDoFn(&VetSdfStatefulWatermark{}, graph.NumMainInputs(graph.MainSingle))
+	if err != nil {
+		t.Fatalf("invalid function: %v", err)
+	}
+	statefulWeFn := (*graph.SplittableDoFn)(dfn)
+
 	// Tests.
 	t.Run("CreateInitialRestriction Invoker (cirInvoker)", func(t *testing.T) {
 		tests := []struct {
@@ -231,21 +239,71 @@ func TestInvokes(t *testing.T) {
 	})
 
 	t.Run("CreateWatermarkEstimator Invoker (cweInvoker)", func(t *testing.T) {
-		fn := sdf.CreateWatermarkEstimatorFn()
-		invoker, err := newCreateWatermarkEstimatorInvoker(fn)
+		tests := []struct {
+			name  string
+			sdf   *graph.SplittableDoFn
+			state int
+			want  VetWatermarkEstimator
+		}{
+			{
+				name:  "Non-stateful",
+				sdf:   sdf,
+				state: 1,
+				want:  VetWatermarkEstimator{State: -1},
+			}, {
+				name:  "Stateful",
+				sdf:   statefulWeFn,
+				state: 11,
+				want:  VetWatermarkEstimator{State: 11},
+			},
+		}
+
+		for _, test := range tests {
+			test := test
+			fn := test.sdf.CreateWatermarkEstimatorFn()
+			t.Run(test.name, func(t *testing.T) {
+				invoker, err := newCreateWatermarkEstimatorInvoker(fn)
+				if err != nil {
+					t.Fatalf("newCreateWatermarkEstimatorInvoker failed: %v", err)
+				}
+				got := invoker.Invoke(test.state)
+				want := &test.want
+				if !cmp.Equal(got, want) {
+					t.Errorf("Invoke() has incorrect output: got: %v, want: %v", got, want)
+				}
+				invoker.Reset()
+				for i, arg := range invoker.args {
+					if arg != nil {
+						t.Errorf("Reset() failed to empty all args. args[%v] = %v", i, arg)
+					}
+				}
+			})
+		}
+	})
+
+	t.Run("InitialWatermarkEstimatorState Invoker (iwesInvoker)", func(t *testing.T) {
+		fn := statefulWeFn.InitialWatermarkEstimatorStateFn()
+		invoker, err := newInitialWatermarkEstimatorStateInvoker(fn)
 		if err != nil {
-			t.Fatalf("newCreateWatermarkEstimatorInvoker failed: %v", err)
+			t.Fatalf("newInitialWatermarkEstimatorStateInvoker failed: %v", err)
 		}
-		got := invoker.Invoke()
-		want := &VetWatermarkEstimator{}
-		if !cmp.Equal(got, want) {
+		got := invoker.Invoke(&VetRestriction{ID: "Sdf"}, &FullValue{Elm: 1, Timestamp: mtime.ZeroTimestamp})
+		want := 1
+		if got != want {
 			t.Errorf("Invoke() has incorrect output: got: %v, want: %v", got, want)
 		}
-		invoker.Reset()
-		for i, arg := range invoker.args {
-			if arg != nil {
-				t.Errorf("Reset() failed to empty all args. args[%v] = %v", i, arg)
-			}
+	})
+
+	t.Run("WatermarkEstimatorState Invoker (wesInvoker)", func(t *testing.T) {
+		fn := statefulWeFn.WatermarkEstimatorStateFn()
+		invoker, err := newWatermarkEstimatorStateInvoker(fn)
+		if err != nil {
+			t.Fatalf("newWatermarkEstimatorStateInvoker failed: %v", err)
+		}
+		got := invoker.Invoke(&VetWatermarkEstimator{State: 11})
+		want := 11
+		if got != want {
+			t.Errorf("Invoke() has incorrect output: got: %v, want: %v", got, want)
 		}
 	})
 }
@@ -288,7 +346,9 @@ func (rt *VetRTracker) TrySplit(_ float64) (interface{}, interface{}, error) {
 	return nil, nil, nil
 }
 
-type VetWatermarkEstimator struct{}
+type VetWatermarkEstimator struct {
+	State int
+}
 
 func (e *VetWatermarkEstimator) CurrentWatermark() time.Time {
 	return time.Date(2022, time.January, 1, 1, 0, 0, 0, time.UTC)
@@ -340,7 +400,7 @@ func (fn *VetSdf) CreateTracker(rest *VetRestriction) *VetRTracker {
 
 // CreateWatermarkEstimator creates a watermark estimator to be used by the Sdf
 func (fn *VetSdf) CreateWatermarkEstimator() *VetWatermarkEstimator {
-	return &VetWatermarkEstimator{}
+	return &VetWatermarkEstimator{State: -1}
 }
 
 // ProcessElement emits the restriction from the restriction tracker it
@@ -356,6 +416,57 @@ func (fn *VetSdf) ProcessElement(rt *VetRTracker, i int, emit func(*VetRestricti
 	emit(rest)
 }
 
+type VetSdfStatefulWatermark struct {
+}
+
+func (fn *VetSdfStatefulWatermark) CreateInitialRestriction(i int) *VetRestriction {
+	return &VetRestriction{ID: "Sdf", Val: i, CreateRest: true}
+}
+
+func (fn *VetSdfStatefulWatermark) SplitRestriction(i int, rest *VetRestriction) []*VetRestriction {
+	rest.SplitRest = true
+	rest.Val = i
+
+	rest1 := rest.copy()
+	rest1.ID += ".1"
+	rest2 := rest.copy()
+	rest2.ID += ".2"
+
+	return []*VetRestriction{&rest1, &rest2}
+}
+
+func (fn *VetSdfStatefulWatermark) RestrictionSize(i int, rest *VetRestriction) float64 {
+	rest.Key = nil
+	rest.Val = i
+	rest.RestSize = true
+	return (float64)(i)
+}
+
+func (fn *VetSdfStatefulWatermark) CreateTracker(rest *VetRestriction) *VetRTracker {
+	rest.CreateTracker = true
+	return &VetRTracker{rest}
+}
+
+func (fn *VetSdfStatefulWatermark) InitialWatermarkEstimatorState(_ typex.EventTime, _ *VetRestriction, element int) int {
+	return 1
+}
+
+func (fn *VetSdfStatefulWatermark) CreateWatermarkEstimator(state int) *VetWatermarkEstimator {
+	return &VetWatermarkEstimator{State: state}
+}
+
+func (fn *VetSdfStatefulWatermark) WatermarkEstimatorState(e *VetWatermarkEstimator) int {
+	return e.State
+}
+
+func (fn *VetSdfStatefulWatermark) ProcessElement(rt *VetRTracker, i int, emit func(*VetRestriction)) {
+	rest := rt.Rest
+	rest.Key = nil
+	rest.Val = i
+	rest.ProcessElm = true
+	emit(rest)
+}
+
 // VetKvSdf runs an SDF In order to test that these methods get called properly,
 // each method will flip the corresponding flag in the passed in VetRestriction,
 // overwrite the restriction's Key and Val with the last seen input elements,
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
index e5729bfb031..1f9c56bdbac 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
@@ -64,6 +64,10 @@ func TestSdfNodes(t *testing.T) {
 	if err != nil {
 		t.Fatalf("invalid function: %v", err)
 	}
+	statefulWeFn, err := graph.NewDoFn(&VetSdfStatefulWatermark{}, graph.NumMainInputs(graph.MainSingle))
+	if err != nil {
+		t.Fatalf("invalid function: %v", err)
+	}
 
 	// Validate PairWithRestriction matches its contract and properly invokes
 	// SDF method CreateInitialRestriction.
@@ -90,7 +94,34 @@ func TestSdfNodes(t *testing.T) {
 						Timestamp: testTimestamp,
 						Windows:   testWindows,
 					},
-					Elm2:      &VetRestriction{ID: "Sdf", CreateRest: true, Val: 1},
+					Elm2: &FullValue{
+						Elm:  &VetRestriction{ID: "Sdf", CreateRest: true, Val: 1},
+						Elm2: false,
+					},
+					Timestamp: testTimestamp,
+					Windows:   testWindows,
+				},
+			},
+			{
+				name: "SingleElemStatefulWatermarkEstimating",
+				fn:   statefulWeFn,
+				in: FullValue{
+					Elm:       1,
+					Elm2:      nil,
+					Timestamp: testTimestamp,
+					Windows:   testWindows,
+				},
+				want: FullValue{
+					Elm: &FullValue{
+						Elm:       1,
+						Elm2:      nil,
+						Timestamp: testTimestamp,
+						Windows:   testWindows,
+					},
+					Elm2: &FullValue{
+						Elm:  &VetRestriction{ID: "Sdf", CreateRest: true, Val: 1},
+						Elm2: 1,
+					},
 					Timestamp: testTimestamp,
 					Windows:   testWindows,
 				},
@@ -111,7 +142,10 @@ func TestSdfNodes(t *testing.T) {
 						Timestamp: testTimestamp,
 						Windows:   testWindows,
 					},
-					Elm2:      &VetRestriction{ID: "KvSdf", CreateRest: true, Key: 1, Val: 2},
+					Elm2: &FullValue{
+						Elm:  &VetRestriction{ID: "KvSdf", CreateRest: true, Key: 1, Val: 2},
+						Elm2: false,
+					},
 					Timestamp: testTimestamp,
 					Windows:   testWindows,
 				},
@@ -154,7 +188,10 @@ func TestSdfNodes(t *testing.T) {
 						Timestamp: testTimestamp,
 						Windows:   testWindows,
 					},
-					Elm2:      &VetRestriction{ID: "Sdf"},
+					Elm2: &FullValue{
+						Elm:  &VetRestriction{ID: "Sdf"},
+						Elm2: 1,
+					},
 					Timestamp: testTimestamp,
 					Windows:   testWindows,
 				},
@@ -167,7 +204,10 @@ func TestSdfNodes(t *testing.T) {
 								Timestamp: testTimestamp,
 								Windows:   testWindows,
 							},
-							Elm2: &VetRestriction{ID: "Sdf.1", SplitRest: true, RestSize: true, Val: 1},
+							Elm2: &FullValue{
+								Elm:  &VetRestriction{ID: "Sdf.1", SplitRest: true, RestSize: true, Val: 1},
+								Elm2: 1,
+							},
 						},
 						Elm2:      1.0,
 						Timestamp: testTimestamp,
@@ -181,7 +221,10 @@ func TestSdfNodes(t *testing.T) {
 								Timestamp: testTimestamp,
 								Windows:   testWindows,
 							},
-							Elm2: &VetRestriction{ID: "Sdf.2", SplitRest: true, RestSize: true, Val: 1},
+							Elm2: &FullValue{
+								Elm:  &VetRestriction{ID: "Sdf.2", SplitRest: true, RestSize: true, Val: 1},
+								Elm2: 1,
+							},
 						},
 						Elm2:      1.0,
 						Timestamp: testTimestamp,
@@ -199,7 +242,10 @@ func TestSdfNodes(t *testing.T) {
 						Timestamp: testTimestamp,
 						Windows:   testWindows,
 					},
-					Elm2:      &VetRestriction{ID: "KvSdf"},
+					Elm2: &FullValue{
+						Elm:  &VetRestriction{ID: "KvSdf"},
+						Elm2: false,
+					},
 					Timestamp: testTimestamp,
 					Windows:   testWindows,
 				},
@@ -212,7 +258,10 @@ func TestSdfNodes(t *testing.T) {
 								Timestamp: testTimestamp,
 								Windows:   testWindows,
 							},
-							Elm2: &VetRestriction{ID: "KvSdf.1", SplitRest: true, RestSize: true, Key: 1, Val: 2},
+							Elm2: &FullValue{
+								Elm:  &VetRestriction{ID: "KvSdf.1", SplitRest: true, RestSize: true, Key: 1, Val: 2},
+								Elm2: false,
+							},
 						},
 						Elm2:      3.0,
 						Timestamp: testTimestamp,
@@ -226,7 +275,10 @@ func TestSdfNodes(t *testing.T) {
 								Timestamp: testTimestamp,
 								Windows:   testWindows,
 							},
-							Elm2: &VetRestriction{ID: "KvSdf.2", SplitRest: true, RestSize: true, Key: 1, Val: 2},
+							Elm2: &FullValue{
+								Elm:  &VetRestriction{ID: "KvSdf.2", SplitRest: true, RestSize: true, Key: 1, Val: 2},
+								Elm2: false,
+							},
 						},
 						Elm2:      3.0,
 						Timestamp: testTimestamp,
@@ -244,7 +296,10 @@ func TestSdfNodes(t *testing.T) {
 						Timestamp: testTimestamp,
 						Windows:   testWindows,
 					},
-					Elm2:      &VetRestriction{ID: "Sdf"},
+					Elm2: &FullValue{
+						Elm:  &VetRestriction{ID: "Sdf"},
+						Elm2: false,
+					},
 					Timestamp: testTimestamp,
 					Windows:   testWindows,
 				},
@@ -296,7 +351,10 @@ func TestSdfNodes(t *testing.T) {
 						Timestamp: testTimestamp,
 						Windows:   testWindows,
 					},
-					Elm2:      offsetrange.Restriction{Start: 0, End: 4},
+					Elm2: &FullValue{
+						Elm:  offsetrange.Restriction{Start: 0, End: 4},
+						Elm2: false,
+					},
 					Timestamp: testTimestamp,
 					Windows:   testWindows,
 				},
@@ -338,8 +396,33 @@ func TestSdfNodes(t *testing.T) {
 				fn:   dfn,
 				in: FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf"},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: false,
+						},
+					},
+					Elm2:      1.0,
+					Timestamp: testTimestamp,
+					Windows:   testWindows,
+				},
+				want: FullValue{
+					Elm:       &VetRestriction{ID: "Sdf", CreateTracker: true, ProcessElm: true, Val: 1},
+					Elm2:      nil,
+					Timestamp: testTimestamp,
+					Windows:   testWindows,
+				},
+			},
+			{
+				name: "SingleElemStatefulWatermarkEstimating",
+				fn:   statefulWeFn,
+				in: FullValue{
+					Elm: &FullValue{
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: 1,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -363,7 +446,10 @@ func TestSdfNodes(t *testing.T) {
 							Timestamp: testTimestamp,
 							Windows:   testWindows,
 						},
-						Elm2: &VetRestriction{ID: "KvSdf"},
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "KvSdf"},
+							Elm2: false,
+						},
 					},
 					Elm2:      3.0,
 					Timestamp: testTimestamp,
@@ -495,6 +581,10 @@ func TestAsSplittableUnit(t *testing.T) {
 	if err != nil {
 		t.Fatalf("invalid function: %v", err)
 	}
+	statefulWeFn, err := graph.NewDoFn(&VetSdfStatefulWatermark{}, graph.NumMainInputs(graph.MainSingle))
+	if err != nil {
+		t.Fatalf("invalid function: %v", err)
+	}
 	multiWindows := []typex.Window{
 		window.IntervalWindow{Start: 10, End: 20},
 		window.IntervalWindow{Start: 11, End: 21},
@@ -537,8 +627,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				// but the element is still built to be valid.
 				elm := FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf"},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -549,7 +642,7 @@ func TestAsSplittableUnit(t *testing.T) {
 				n := &ParDo{UID: 1, Fn: dfn, Out: []Node{}}
 				node := &ProcessSizedElementsAndRestrictions{PDo: n}
 				node.rt = &SplittableUnitRTracker{
-					VetRTracker: VetRTracker{Rest: elm.Elm.(*FullValue).Elm2.(*VetRestriction)},
+					VetRTracker: VetRTracker{Rest: elm.Elm.(*FullValue).Elm2.(*FullValue).Elm.(*VetRestriction)},
 					Done:        test.doneWork,
 					Remaining:   test.remainingWork,
 					ThisIsDone:  false,
@@ -589,8 +682,52 @@ func TestAsSplittableUnit(t *testing.T) {
 				frac: 0.5,
 				in: FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf"},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: false,
+						},
+					},
+					Elm2:      1.0,
+					Timestamp: testTimestamp,
+					Windows:   testWindows,
+				},
+				wantPrimaries: []*FullValue{{
+					Elm: &FullValue{
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf.1", RestSize: true, Val: 1},
+							Elm2: false,
+						},
+					},
+					Elm2:      1.0,
+					Timestamp: testTimestamp,
+					Windows:   testWindows,
+				}},
+				wantResiduals: []*FullValue{{
+					Elm: &FullValue{
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf.2", RestSize: true, Val: 1},
+							Elm2: false,
+						},
+					},
+					Elm2:      1.0,
+					Timestamp: testTimestamp,
+					Windows:   testWindows,
+				}},
+			},
+			{
+				name: "SingleElemStatefulWatermarkEstimating",
+				fn:   statefulWeFn,
+				frac: 0.5,
+				in: FullValue{
+					Elm: &FullValue{
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: 0,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -598,8 +735,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				},
 				wantPrimaries: []*FullValue{{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf.1", RestSize: true, Val: 1},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf.1", RestSize: true, Val: 1},
+							Elm2: 1,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -607,8 +747,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				}},
 				wantResiduals: []*FullValue{{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf.2", RestSize: true, Val: 1},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf.2", RestSize: true, Val: 1},
+							Elm2: 1,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -625,7 +768,10 @@ func TestAsSplittableUnit(t *testing.T) {
 							Elm:  1,
 							Elm2: 2,
 						},
-						Elm2: &VetRestriction{ID: "KvSdf"},
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "KvSdf"},
+							Elm2: false,
+						},
 					},
 					Elm2:      3.0,
 					Timestamp: testTimestamp,
@@ -637,7 +783,10 @@ func TestAsSplittableUnit(t *testing.T) {
 							Elm:  1,
 							Elm2: 2,
 						},
-						Elm2: &VetRestriction{ID: "KvSdf.1", RestSize: true, Key: 1, Val: 2},
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "KvSdf.1", RestSize: true, Key: 1, Val: 2},
+							Elm2: false,
+						},
 					},
 					Elm2:      3.0,
 					Timestamp: testTimestamp,
@@ -649,7 +798,10 @@ func TestAsSplittableUnit(t *testing.T) {
 							Elm:  1,
 							Elm2: 2,
 						},
-						Elm2: &VetRestriction{ID: "KvSdf.2", RestSize: true, Key: 1, Val: 2},
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "KvSdf.2", RestSize: true, Key: 1, Val: 2},
+							Elm2: false,
+						},
 					},
 					Elm2:      3.0,
 					Timestamp: testTimestamp,
@@ -663,8 +815,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				frac:   0.5,
 				in: FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf"},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -681,8 +836,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				frac: 0.125, // Should be in the middle of the first (current) window.
 				in: FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf"},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -690,8 +848,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				},
 				wantPrimaries: []*FullValue{{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf.1", RestSize: true, Val: 1},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf.1", RestSize: true, Val: 1},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -699,16 +860,22 @@ func TestAsSplittableUnit(t *testing.T) {
 				}},
 				wantResiduals: []*FullValue{{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf.2", RestSize: true, Val: 1},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf.2", RestSize: true, Val: 1},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
 					Windows:   testMultiWindows[0:1],
 				}, {
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -723,8 +890,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				frac: 0.55,
 				in: FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf"},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -732,8 +902,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				},
 				wantPrimaries: []*FullValue{{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -741,8 +914,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				}},
 				wantResiduals: []*FullValue{{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -758,8 +934,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				doneRt: true,
 				in: FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf"},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -767,8 +946,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				},
 				wantPrimaries: []*FullValue{{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -776,8 +958,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				}},
 				wantResiduals: []*FullValue{{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -792,8 +977,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				frac: 0.95, // Should round to end of element and cause a no-op.
 				in: FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf"},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -807,10 +995,10 @@ func TestAsSplittableUnit(t *testing.T) {
 			test := test
 			t.Run(test.name, func(t *testing.T) {
 				// Setup, create transforms, inputs, and desired outputs.
-				n := &ParDo{UID: 1, Fn: test.fn, Out: []Node{}}
+				n := &ParDo{UID: 1, Fn: test.fn, Out: []Node{}, we: &VetWatermarkEstimator{State: 1}}
 				node := &ProcessSizedElementsAndRestrictions{PDo: n}
 				node.rt = &SplittableUnitRTracker{
-					VetRTracker: VetRTracker{Rest: test.in.Elm.(*FullValue).Elm2.(*VetRestriction)},
+					VetRTracker: VetRTracker{Rest: test.in.Elm.(*FullValue).Elm2.(*FullValue).Elm.(*VetRestriction)},
 					Done:        0,
 					Remaining:   1.0,
 					ThisIsDone:  test.doneRt,
@@ -850,8 +1038,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				fn:   pdfn,
 				in: FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &offsetrange.Restriction{Start: 0, End: 4},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &offsetrange.Restriction{Start: 0, End: 4},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -863,8 +1054,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				fn:   rdfn,
 				in: FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &offsetrange.Restriction{Start: 0, End: 4},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &offsetrange.Restriction{Start: 0, End: 4},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -878,7 +1072,7 @@ func TestAsSplittableUnit(t *testing.T) {
 				// Setup, create transforms, inputs, and desired outputs.
 				n := &ParDo{UID: 1, Fn: test.fn, Out: []Node{}}
 				node := &ProcessSizedElementsAndRestrictions{PDo: n}
-				node.rt = sdf.RTracker(offsetrange.NewTracker(*test.in.Elm.(*FullValue).Elm2.(*offsetrange.Restriction)))
+				node.rt = sdf.RTracker(offsetrange.NewTracker(*test.in.Elm.(*FullValue).Elm2.(*FullValue).Elm.(*offsetrange.Restriction)))
 				node.elm = &test.in
 				node.numW = len(test.in.Windows)
 				node.currW = 0
@@ -911,8 +1105,11 @@ func TestAsSplittableUnit(t *testing.T) {
 				fn:   dfn,
 				in: FullValue{
 					Elm: &FullValue{
-						Elm:  1,
-						Elm2: &VetRestriction{ID: "Sdf"},
+						Elm: 1,
+						Elm2: &FullValue{
+							Elm:  &VetRestriction{ID: "Sdf"},
+							Elm2: false,
+						},
 					},
 					Elm2:      1.0,
 					Timestamp: testTimestamp,
@@ -979,8 +1176,11 @@ func TestMultiWindowProcessing(t *testing.T) {
 	// Create a plan with a single valid element as input to ProcessElement.
 	in := FullValue{
 		Elm: &FullValue{
-			Elm:  1,
-			Elm2: offsetrange.Restriction{Start: 0, End: 4},
+			Elm: 1,
+			Elm2: &FullValue{
+				Elm:  offsetrange.Restriction{Start: 0, End: 4},
+				Elm2: false,
+			},
 		},
 		Elm2:      4.0,
 		Timestamp: testTimestamp,
diff --git a/sdks/go/pkg/beam/core/runtime/genx/genx.go b/sdks/go/pkg/beam/core/runtime/genx/genx.go
index 422ee2900c7..b5bf766c268 100644
--- a/sdks/go/pkg/beam/core/runtime/genx/genx.go
+++ b/sdks/go/pkg/beam/core/runtime/genx/genx.go
@@ -117,6 +117,12 @@ func handleDoFn(fn *graph.DoFn, c cache) {
 	}
 	c.pullMethod(sdf.CreateWatermarkEstimatorFn())
 	c.regType(sdf.WatermarkEstimatorT())
+	if !sdf.IsStatefulWatermarkEstimating() {
+		return
+	}
+	c.pullMethod(sdf.InitialWatermarkEstimatorStateFn())
+	c.pullMethod(sdf.WatermarkEstimatorStateFn())
+	c.regType(sdf.WatermarkEstimatorStateT())
 }
 
 func handleCombineFn(fn *graph.CombineFn, c cache) {
diff --git a/sdks/go/pkg/beam/core/runtime/genx/genx_test.go b/sdks/go/pkg/beam/core/runtime/genx/genx_test.go
index cc219d6f295..24f96a5bc7a 100644
--- a/sdks/go/pkg/beam/core/runtime/genx/genx_test.go
+++ b/sdks/go/pkg/beam/core/runtime/genx/genx_test.go
@@ -22,6 +22,7 @@ import (
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
@@ -41,6 +42,7 @@ func TestRegisterDoFn(t *testing.T) {
 	tO := reflect.TypeOf((*O)(nil)).Elem()
 	tRt := reflect.TypeOf((*sdf.LockRTracker)(nil)).Elem()
 	tWe := reflect.TypeOf((*sdf.WallTimeWatermarkEstimator)(nil)).Elem()
+	tWes := reflect.TypeOf((*WatermarkEstimatorState)(nil)).Elem()
 
 	tests := []struct {
 		name   string
@@ -69,7 +71,7 @@ func TestRegisterDoFn(t *testing.T) {
 		{"DoFn01 pointer reflect", reflect.TypeOf(&DoFn01{}), true, false, []reflect.Type{tDoFn01, tR, tS}},
 		{"DoFn02 reflect - filtered types", tDoFn02, true, false, []reflect.Type{tDoFn02}},
 		{"CombineFn01 reflect - combine methods", tCmbFn01, true, false, []reflect.Type{tCmbFn01, tA, tI, tO}},
-		{"DoFn03 reflect - sdf methods", tDoFn03, true, false, []reflect.Type{tDoFn03, tRt, tWe, tR}},
+		{"DoFn03 reflect - sdf methods", tDoFn03, true, false, []reflect.Type{typex.EventTimeType, tDoFn03, tRt, tWe, tWes, tR}},
 		{"DoFn04 reflect - containers", tDoFn04, true, false, []reflect.Type{tDoFn04, tR, tS, tT, tA, tI, tO}},
 	}
 
@@ -225,10 +227,20 @@ func (fn *DoFn03) CreateTracker(rest R) *sdf.LockRTracker {
 	return &sdf.LockRTracker{Rt: RT{}}
 }
 
-func (fn *DoFn03) CreateWatermarkEstimator() *sdf.WallTimeWatermarkEstimator {
+type WatermarkEstimatorState struct{}
+
+func (fn *DoFn03) WatermarkEstimatorState(estimator *sdf.WallTimeWatermarkEstimator) WatermarkEstimatorState {
+	return WatermarkEstimatorState{}
+}
+
+func (fn *DoFn03) CreateWatermarkEstimator(state WatermarkEstimatorState) *sdf.WallTimeWatermarkEstimator {
 	return &sdf.WallTimeWatermarkEstimator{}
 }
 
+func (fn *DoFn03) InitialWatermarkEstimatorState(ts typex.EventTime, rest R, s string) WatermarkEstimatorState {
+	return WatermarkEstimatorState{}
+}
+
 type DoFn04 struct{}
 
 func (*DoFn04) ProcessElement([4]R, map[S]T, func(*O) bool, func() func(*I) bool, func([]A)) {
diff --git a/sdks/go/pkg/beam/pardo.go b/sdks/go/pkg/beam/pardo.go
index 45176febf31..aad86b6a02e 100644
--- a/sdks/go/pkg/beam/pardo.go
+++ b/sdks/go/pkg/beam/pardo.go
@@ -17,6 +17,8 @@ package beam
 
 import (
 	"fmt"
+	"reflect"
+
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
@@ -64,9 +66,16 @@ func TryParDo(s Scope, dofn interface{}, col PCollection, opts ...Option) ([]PCo
 	}
 
 	var rc *coder.Coder
+	// Sdfs will always encode restrictions as KV<restriction, watermark state | bool(false)>
 	if fn.IsSplittable() {
 		sdf := (*graph.SplittableDoFn)(fn)
-		rc, err = inferCoder(typex.New(sdf.RestrictionT()))
+		restT := typex.New(sdf.RestrictionT())
+		// If no watermark estimator state, use boolean as a placeholder
+		weT := typex.New(reflect.TypeOf(true))
+		if sdf.IsStatefulWatermarkEstimating() {
+			weT = typex.New(sdf.WatermarkEstimatorStateT())
+		}
+		rc, err = inferCoder(typex.NewKV(restT, weT))
 		if err != nil {
 			return nil, addParDoCtx(err, s)
 		}