This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e91f92e678e [BEAM-11105] Stateful watermark estimation (#17374)
e91f92e678e is described below

commit e91f92e678ec589888bb0d691039c57e3aa88c88
Author: Danny McCormick <[email protected]>
AuthorDate: Wed Apr 27 00:03:22 2022 -0400

    [BEAM-11105] Stateful watermark estimation (#17374)
---
 sdks/go/pkg/beam/core/graph/fn.go                  | 217 +++++++++++---
 sdks/go/pkg/beam/core/graph/fn_test.go             | 172 ++++++++++-
 .../go/pkg/beam/core/runtime/exec/dynsplit_test.go |  14 +-
 sdks/go/pkg/beam/core/runtime/exec/sdf.go          | 126 +++++---
 sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go | 173 ++++++++++-
 .../beam/core/runtime/exec/sdf_invokers_test.go    | 137 ++++++++-
 sdks/go/pkg/beam/core/runtime/exec/sdf_test.go     | 320 +++++++++++++++++----
 sdks/go/pkg/beam/core/runtime/genx/genx.go         |   6 +
 sdks/go/pkg/beam/core/runtime/genx/genx_test.go    |  16 +-
 sdks/go/pkg/beam/pardo.go                          |  11 +-
 10 files changed, 1018 insertions(+), 174 deletions(-)

diff --git a/sdks/go/pkg/beam/core/graph/fn.go 
b/sdks/go/pkg/beam/core/graph/fn.go
index 69460ee9f4a..775a3dfe9f3 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -168,7 +168,9 @@ const (
        restrictionSizeName          = "RestrictionSize"
        createTrackerName            = "CreateTracker"
 
-       createWatermarkEstimatorName = "CreateWatermarkEstimator"
+       createWatermarkEstimatorName       = "CreateWatermarkEstimator"
+       initialWatermarkEstimatorStateName = "InitialWatermarkEstimatorState"
+       watermarkEstimatorStateName        = "WatermarkEstimatorState"
 
        createAccumulatorName = "CreateAccumulator"
        addInputName          = "AddInput"
@@ -190,6 +192,8 @@ var doFnNames = []string{
        restrictionSizeName,
        createTrackerName,
        createWatermarkEstimatorName,
+       initialWatermarkEstimatorStateName,
+       watermarkEstimatorStateName,
 }
 
 var requiredSdfNames = []string{
@@ -201,6 +205,8 @@ var requiredSdfNames = []string{
 
 var watermarkEstimationNames = []string{
        createWatermarkEstimatorName,
+       initialWatermarkEstimatorStateName,
+       watermarkEstimatorStateName,
 }
 
 var combineFnNames = []string{
@@ -314,7 +320,7 @@ func (f *SplittableDoFn) IsWatermarkEstimating() bool {
        return ok
 }
 
-// createWatermarkEstimatorFn returns the "createWatermarkEstimator" function, if present
+// CreateWatermarkEstimatorFn returns the "CreateWatermarkEstimator" method of the SDF, or nil if it is not present.
 func (f *SplittableDoFn) CreateWatermarkEstimatorFn() *funcx.Fn {
        return f.methods[createWatermarkEstimatorName]
 }
@@ -324,6 +330,27 @@ func (f *SplittableDoFn) WatermarkEstimatorT() 
reflect.Type {
        return f.CreateWatermarkEstimatorFn().Ret[0].T
 }
 
+// IsStatefulWatermarkEstimating returns whether the DoFn implements custom watermark state, i.e. defines a WatermarkEstimatorState method.
+func (f *SplittableDoFn) IsStatefulWatermarkEstimating() bool {
+       _, ok := f.methods[watermarkEstimatorStateName]
+       return ok
+}
+
+// InitialWatermarkEstimatorStateFn returns the "InitialWatermarkEstimatorState" method of the SDF, or nil if it is not present.
+func (f *SplittableDoFn) InitialWatermarkEstimatorStateFn() *funcx.Fn {
+       return f.methods[initialWatermarkEstimatorStateName]
+}
+
+// WatermarkEstimatorStateFn returns the "WatermarkEstimatorState" method of the SDF, or nil if it is not present.
+func (f *SplittableDoFn) WatermarkEstimatorStateFn() *funcx.Fn {
+       return f.methods[watermarkEstimatorStateName]
+}
+
+// WatermarkEstimatorStateT returns the watermark estimator state type (the return type of WatermarkEstimatorState); call it only when IsStatefulWatermarkEstimating is true.
+func (f *SplittableDoFn) WatermarkEstimatorStateT() reflect.Type {
+       return f.WatermarkEstimatorStateFn().Ret[0].T
+}
+
 // TODO(herohde) 5/19/2017: we can sometimes detect whether the main input 
must be
 // a KV or not based on the other signatures (unless we're more loose about 
which
 // sideinputs are present). Bind should respect that.
@@ -519,7 +546,7 @@ func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) {
        }
 
        if isWatermarkEstimating {
-               err := validateWatermarkSig(fn)
+               err := validateWatermarkSig(fn, int(numMainIn))
                if err != nil {
                        return nil, addContext(err, fn)
                }
@@ -852,11 +879,11 @@ func validateSdfSigTypes(fn *Fn, num int) error {
                method := fn.methods[name]
                switch name {
                case createInitialRestrictionName:
-                       if err := validateSdfElementT(fn, 
createInitialRestrictionName, method, num); err != nil {
+                       if err := validateSdfElementT(fn, 
createInitialRestrictionName, method, num, 0); err != nil {
                                return err
                        }
                case splitRestrictionName:
-                       if err := validateSdfElementT(fn, splitRestrictionName, 
method, num); err != nil {
+                       if err := validateSdfElementT(fn, splitRestrictionName, 
method, num, 0); err != nil {
                                return err
                        }
                        if method.Param[num].T != restrictionT {
@@ -877,7 +904,7 @@ func validateSdfSigTypes(fn *Fn, num int) error {
                                        splitRestrictionName, 0, 
method.Ret[0].T, reflect.SliceOf(restrictionT), createInitialRestrictionName, 
splitRestrictionName)
                        }
                case restrictionSizeName:
-                       if err := validateSdfElementT(fn, restrictionSizeName, 
method, num); err != nil {
+                       if err := validateSdfElementT(fn, restrictionSizeName, 
method, num, 0); err != nil {
                                return err
                        }
                        if method.Param[num].T != restrictionT {
@@ -928,15 +955,15 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 
 // validateSdfElementT validates that element types in an SDF method are
 // consistent with the ProcessElement method. This method assumes that the
-// first 'num' parameters to the SDF method are the elements.
-func validateSdfElementT(fn *Fn, name string, method *funcx.Fn, num int) error {
+// 'num' parameters starting at index startIndex are the elements.
+func validateSdfElementT(fn *Fn, name string, method *funcx.Fn, num int, startIndex int) error {
        // ProcessElement is the most canonical source of the element type. We can
        // processFn is valid by this point and skip unnecessary validation.
        processFn := fn.methods[processElementName]
        pos, _, _ := processFn.Inputs()
 
        for i := 0; i < num; i++ {
-               if method.Param[i].T != processFn.Param[pos+i].T {
+               if method.Param[i+startIndex].T != processFn.Param[pos+i].T {
                        err := errors.Errorf("mismatched element type in method %v, param %v. got: %v, want: %v",
-                               name, i, method.Param[i].T, processFn.Param[pos+i].T)
+                               name, i+startIndex, method.Param[i+startIndex].T, processFn.Param[pos+i].T)
                        return errors.SetTopLevelMsgf(err, "Mismatched element type in method %v, "+
@@ -961,45 +988,163 @@ func validateIsWatermarkEstimating(fn *Fn, isSdf bool) (bool, error) {
 }
 
 // validateWatermarkSig validates that all watermark related functions are valid
-func validateWatermarkSig(fn *Fn) error {
-       paramRange := map[string][]int{
-               createWatermarkEstimatorName: []int{0, 0},
-       }
+func validateWatermarkSig(fn *Fn, numMainIn int) error {
        returnNum := 1 // TODO(BEAM-3301): Enable optional error params in SDF methods.
 
        watermarkEstimatorT := reflect.TypeOf((*sdf.WatermarkEstimator)(nil)).Elem()
+       method := fn.methods[createWatermarkEstimatorName]
+
+       // Check the return signature first: validateStatefulWatermarkSig indexes
+       // Ret[0] of this method and would panic if it returned nothing.
+       if len(method.Ret) != returnNum {
+               err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v",
+                       createWatermarkEstimatorName, len(method.Ret), returnNum)
+               return errors.SetTopLevelMsgf(err, "unexpected number of return values in method %v. "+
+                       "got: %v, want: %v. Check that the signature conforms to the expected signature for %v.",
+                       createWatermarkEstimatorName, len(method.Ret), returnNum, createWatermarkEstimatorName)
+       } else if !method.Ret[0].T.Implements(watermarkEstimatorT) {
+               err := errors.Errorf("invalid output type in method %v, return %v: %v does not implement sdf.WatermarkEstimator",
+                       createWatermarkEstimatorName, 0, method.Ret[0].T)
+               return errors.SetTopLevelMsgf(err, "invalid output type in method %v, "+
+                       "return value at index %v (type: %v). Output of method %v must implement sdf.WatermarkEstimator.",
+                       createWatermarkEstimatorName, 0, method.Ret[0].T, createWatermarkEstimatorName)
+       }
+
+       if len(method.Param) > 1 {
+               err := errors.Errorf("unexpected number of params in method %v. got: %v, want number in range: 0 to 1",
+                       createWatermarkEstimatorName, len(method.Param))
+               return errors.SetTopLevelMsgf(err, "unexpected number of parameters in method %v. "+
+                       "got: %v, want number in range: 0 to 1. Check that the signature conforms to the expected signature for %v.",
+                       createWatermarkEstimatorName, len(method.Param), createWatermarkEstimatorName)
+       } else if len(method.Param) == 1 {
+               if err := validateStatefulWatermarkSig(fn, numMainIn); err != nil {
+                       return err
+               }
+       } else {
+               // Without a state parameter, no stateful watermark method may be defined.
+               if _, ok := fn.methods[initialWatermarkEstimatorStateName]; ok {
+                       return errors.Errorf("stateful watermark estimation method %v is present, "+
+                               "but CreateWatermarkEstimator doesn't take in a state parameter.", initialWatermarkEstimatorStateName)
+               }
+               if _, ok := fn.methods[watermarkEstimatorStateName]; ok {
+                       return errors.Errorf("stateful watermark estimation method %v is present, "+
+                               "but CreateWatermarkEstimator doesn't take in a state parameter.", watermarkEstimatorStateName)
+               }
+       }
 
+       return nil
+}
+
+// validateStatefulWatermarkSig validates InitialWatermarkEstimatorState and WatermarkEstimatorState against the state and estimator types of CreateWatermarkEstimator.
+func validateStatefulWatermarkSig(fn *Fn, numMainIn int) error {
+       // Store missing method names so we can output them to the user if validation fails.
+       var missing []string
        for _, name := range watermarkEstimationNames {
-               if method, ok := fn.methods[name]; ok {
-                       if len(method.Param) < paramRange[name][0] || len(method.Param) > paramRange[name][1] {
-                               err := errors.Errorf("unexpected number of params in method %v. got: %v, want number in range: %v to %v",
-                                       name, len(method.Param), paramRange[name][0], paramRange[name][1])
-                               return errors.SetTopLevelMsgf(err, "Unexpected number of parameters in method %v. "+
-                                       "Got: %v, Want number in range: %v to %v. Check that the signature conforms to the expected signature for %v, "+
+               _, ok := fn.methods[name]
+               if !ok {
+                       missing = append(missing, name)
+               }
+       }
+       if len(missing) > 0 {
+               return errors.Errorf("not all required stateful watermark estimation methods are present, "+
+                       "but CreateWatermarkEstimator takes in a state parameter. Missing methods: %v", missing)
+       }
+
+       restT := fn.methods[createInitialRestrictionName].Ret[0].T
+       watermarkStateT := fn.methods[createWatermarkEstimatorName].Param[0].T
+       watermarkEstimatorT := fn.methods[createWatermarkEstimatorName].Ret[0].T
+
+       // If number of main inputs is ambiguous, we check for consistency against
+       // CreateInitialRestriction.
+       if numMainIn == int(MainUnknown) {
+               initialRestFn := fn.methods[createInitialRestrictionName]
+               paramNum := len(initialRestFn.Param)
+               switch paramNum {
+               case int(MainSingle), int(MainKv):
+                       numMainIn = paramNum
+               }
+       }
+
+       for _, name := range watermarkEstimationNames {
+               method := fn.methods[name]
+               switch name {
+               case initialWatermarkEstimatorStateName:
+                       if len(method.Param) != numMainIn+2 {
+                               err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v",
+                                       initialWatermarkEstimatorStateName, len(method.Param), numMainIn+2)
+                               return errors.SetTopLevelMsgf(err, "unexpected number of parameters in method %v. "+
+                                       "got: %v, want: %v. Check that the signature conforms to the expected signature for %v, "+
                                        "and that elements in SDF method parameters match elements in %v.",
-                                       name, len(method.Param), paramRange[name][0], paramRange[name][1], name, processElementName)
+                                       initialWatermarkEstimatorStateName, len(method.Param), numMainIn+2, initialWatermarkEstimatorStateName, processElementName)
+                       }
+                       if method.Param[0].T != typex.EventTimeType {
+                               err := errors.Errorf("unexpected parameter type in method %v, param %v. got: %v, want: %v",
+                                       initialWatermarkEstimatorStateName, 0, method.Param[0].T, typex.EventTimeType)
+                               return errors.SetTopLevelMsgf(err, "mismatched event time type in method %v, "+
+                                       "parameter at index %v. got: %v, want: %v.",
+                                       initialWatermarkEstimatorStateName, 0, method.Param[0].T, typex.EventTimeType)
                        }
-                       if len(method.Ret) != returnNum {
-                               err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v",
-                                       name, len(method.Ret), returnNum)
-                               return errors.SetTopLevelMsgf(err, "Unexpected number of return values in method %v. "+
-                                       "Got: %v, Want: %v. Check that the signature conforms to the expected signature for %v.",
-                                       name, len(method.Ret), returnNum, name)
+                       if method.Param[1].T != restT {
+                               err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v",
+                                       initialWatermarkEstimatorStateName, 1, method.Param[1].T, restT)
+                               return errors.SetTopLevelMsgf(err, "mismatched restriction type in method %v, "+
+                                       "parameter at index %v. got: %v, want: %v (from method %v). "+
+                                       "Ensure that all restrictions in an SDF are the same type.",
+                                       initialWatermarkEstimatorStateName, 1, method.Param[1].T, restT, createInitialRestrictionName)
+                       }
+                       if err := validateSdfElementT(fn, initialWatermarkEstimatorStateName, method, numMainIn, 2); err != nil {
+                               return err
                        }
 
-                       switch name {
-                       case createWatermarkEstimatorName:
-                               if !method.Ret[0].T.Implements(watermarkEstimatorT) {
-                                       err := errors.Errorf("invalid output type in method %v, return %v: %v does not implement sdf.WatermarkEstimator",
-                                               createWatermarkEstimatorName, 0, method.Ret[0].T)
-                                       return errors.SetTopLevelMsgf(err, "Invalid output type in method %v, "+
-                                               "return value at index %v (type: %v). Output of method %v must implement sdf.WatermarkEstimator.",
-                                               createWatermarkEstimatorName, 0, method.Ret[0].T, createWatermarkEstimatorName)
-                               }
+                       if len(method.Ret) != 1 {
+                               err := errors.Errorf("unexpected number of elements returned in method %v. got: %v, want %v",
+                                       initialWatermarkEstimatorStateName, len(method.Ret), 1)
+                               return errors.SetTopLevelMsgf(err, "unexpected number of elements returned in method %v. "+
+                                       "got: %v, want %v. Check that the signature conforms to the expected signature for %v.",
+                                       initialWatermarkEstimatorStateName, len(method.Ret), 1, initialWatermarkEstimatorStateName)
+                       }
+                       if method.Ret[0].T != watermarkStateT {
+                               err := errors.Errorf("mismatched output type in method %v, return %v. got: %v, want: %v",
+                                       initialWatermarkEstimatorStateName, 0, method.Ret[0].T, watermarkStateT)
+                               return errors.SetTopLevelMsgf(err, "mismatched output type in method %v, "+
+                                       "return value at index %v. got: %v, want: %v (from method %v). "+
+                                       "Ensure that all watermark states in an SDF are the same type.",
+                                       initialWatermarkEstimatorStateName, 0, method.Ret[0].T, watermarkStateT, createWatermarkEstimatorName)
+                       }
+               case watermarkEstimatorStateName:
+                       if len(method.Param) != 1 {
+                               err := errors.Errorf("unexpected number of params in method %v. got: %v, want %v",
+                                       watermarkEstimatorStateName, len(method.Param), 1)
+                               return errors.SetTopLevelMsgf(err, "unexpected number of parameters in method %v. "+
+                                       "got: %v, want %v. Check that the signature conforms to the expected signature for %v: "+
+                                       "it must take only the watermark estimator returned by %v.",
+                                       watermarkEstimatorStateName, len(method.Param), 1, watermarkEstimatorStateName, createWatermarkEstimatorName)
+                       }
+                       if method.Param[0].T != watermarkEstimatorT {
+                               err := errors.Errorf("mismatched watermark estimator type in method %v, param %v. got: %v, want: %v",
+                                       watermarkEstimatorStateName, 0, method.Param[0].T, watermarkEstimatorT)
+                               return errors.SetTopLevelMsgf(err, "mismatched watermark estimator type in method %v, "+
+                                       "parameter at index %v. got: %v, want: %v (from method %v). "+
+                                       "Ensure that all watermark estimators in an SDF are the same type.",
+                                       watermarkEstimatorStateName, 0, method.Param[0].T, watermarkEstimatorT, createWatermarkEstimatorName)
+                       }
+                       if len(method.Ret) != 1 {
+                               err := errors.Errorf("unexpected number of elements returned in method %v. got: %v, want %v",
+                                       watermarkEstimatorStateName, len(method.Ret), 1)
+                               return errors.SetTopLevelMsgf(err, "unexpected number of elements returned in method %v. "+
+                                       "got: %v, want %v. Check that the signature conforms to the expected signature for %v.",
+                                       watermarkEstimatorStateName, len(method.Ret), 1, watermarkEstimatorStateName)
+                       }
+                       if method.Ret[0].T != watermarkStateT {
+                               err := errors.Errorf("mismatched output type in method %v, return %v. got: %v, want: %v",
+                                       watermarkEstimatorStateName, 0, method.Ret[0].T, watermarkStateT)
+                               return errors.SetTopLevelMsgf(err, "mismatched output type in method %v, "+
+                                       "return value at index %v. got: %v, want: %v (from method %v). "+
+                                       "Ensure that all watermark states in an SDF are the same type.",
+                                       watermarkEstimatorStateName, 0, method.Ret[0].T, watermarkStateT, createWatermarkEstimatorName)
                        }
                }
        }
-
        return nil
 }
 
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go 
b/sdks/go/pkg/beam/core/graph/fn_test.go
index e38c9a34af6..c04d2d07529 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -235,6 +235,9 @@ func TestNewDoFnWatermarkEstimating(t *testing.T) {
                        main mainInputs
                }{
                        {dfn: &GoodWatermarkEstimating{}, main: MainSingle},
+                       {dfn: &GoodWatermarkEstimatingKv{}, main: MainKv},
+                       {dfn: &GoodStatefulWatermarkEstimating{}, main: 
MainSingle},
+                       {dfn: &GoodStatefulWatermarkEstimatingKv{}, main: 
MainKv},
                }
 
                for _, test := range tests {
@@ -255,6 +258,17 @@ func TestNewDoFnWatermarkEstimating(t *testing.T) {
                }{
                        {dfn: &BadWatermarkEstimatingNonSdf{}},
                        {dfn: 
&BadWatermarkEstimatingCreateWatermarkEstimatorReturnType{}},
+                       {dfn: 
&BadStatefulWatermarkEstimatingInconsistentState{}},
+                       {dfn: 
&BadStatefulWatermarkEstimatingInconsistentEstimator{}},
+                       {dfn: 
&BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateParams{}},
+                       {dfn: 
&BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoParams{}},
+                       {dfn: 
&BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateReturns{}},
+                       {dfn: 
&BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoReturns{}},
+                       {dfn: 
&BadStatefulWatermarkEstimatingWrongPositionalParameter0{}},
+                       {dfn: 
&BadStatefulWatermarkEstimatingWrongPositionalParameter1{}},
+                       {dfn: 
&BadStatefulWatermarkEstimatingWrongPositionalParameter2{}},
+                       {dfn: 
&BadStatefulKvWatermarkEstimatingWrongPositionalParameter2{}},
+                       {dfn: &BadStatefulWatermarkEstimatingWrongReturn{}},
                }
                for _, test := range tests {
                        t.Run(reflect.TypeOf(test.dfn).String(), func(t 
*testing.T) {
@@ -693,6 +707,27 @@ func (rt *RTrackerT) GetRestriction() interface{} {
        return nil
 }
 
+type RTracker2T struct{}
+
+func (rt *RTracker2T) TryClaim(interface{}) bool {
+       return false
+}
+func (rt *RTracker2T) GetError() error {
+       return nil
+}
+func (rt *RTracker2T) TrySplit(fraction float64) (interface{}, interface{}, 
error) {
+       return nil, nil, nil
+}
+func (rt *RTracker2T) GetProgress() (float64, float64) {
+       return 0, 0
+}
+func (rt *RTracker2T) IsDone() bool {
+       return true
+}
+func (rt *RTracker2T) GetRestriction() interface{} {
+       return nil
+}
+
 type GoodSdf struct {
        *GoodDoFn
 }
@@ -747,34 +782,64 @@ func (e WatermarkEstimatorT) CurrentWatermark() time.Time 
{
        return time.Now()
 }
 
+type WatermarkEstimator2T struct{}
+
+func (e WatermarkEstimator2T) CurrentWatermark() time.Time {
+       return time.Now()
+}
+
+func (e WatermarkEstimator2T) CurrentWatermark2() time.Time {
+       return time.Now()
+}
+
 type GoodWatermarkEstimating struct {
-       *GoodDoFn
+       *GoodSdf
 }
 
-func (fn *GoodWatermarkEstimating) CreateInitialRestriction(int) RestT {
-       return RestT{}
+func (fn *GoodWatermarkEstimating) CreateWatermarkEstimator() 
WatermarkEstimatorT {
+       return WatermarkEstimatorT{}
 }
 
-func (fn *GoodWatermarkEstimating) SplitRestriction(int, RestT) []RestT {
-       return []RestT{}
+type GoodWatermarkEstimatingKv struct {
+       *GoodSdfKv
+}
+
+func (fn *GoodWatermarkEstimatingKv) CreateWatermarkEstimator() 
WatermarkEstimatorT {
+       return WatermarkEstimatorT{}
+}
+
+type GoodStatefulWatermarkEstimating struct {
+       *GoodSdf
 }
 
-func (fn *GoodWatermarkEstimating) RestrictionSize(int, RestT) float64 {
+func (fn *GoodStatefulWatermarkEstimating) InitialWatermarkEstimatorState(ts 
typex.EventTime, rt RestT, element int) int {
        return 0
 }
 
-func (fn *GoodWatermarkEstimating) CreateTracker(RestT) *RTrackerT {
-       return &RTrackerT{}
+func (fn *GoodStatefulWatermarkEstimating) CreateWatermarkEstimator(state int) 
WatermarkEstimatorT {
+       return WatermarkEstimatorT{}
 }
 
-func (fn *GoodWatermarkEstimating) ProcessElement(*RTrackerT, int) int {
+func (fn *GoodStatefulWatermarkEstimating) WatermarkEstimatorState(estimator 
WatermarkEstimatorT) int {
        return 0
 }
 
-func (fn *GoodWatermarkEstimating) CreateWatermarkEstimator() 
WatermarkEstimatorT {
+type GoodStatefulWatermarkEstimatingKv struct {
+       *GoodSdfKv
+}
+
+func (fn *GoodStatefulWatermarkEstimatingKv) InitialWatermarkEstimatorState(ts 
typex.EventTime, rt RestT, k int, v int) int {
+       return 0
+}
+
+func (fn *GoodStatefulWatermarkEstimatingKv) CreateWatermarkEstimator(state 
int) WatermarkEstimatorT {
        return WatermarkEstimatorT{}
 }
 
+func (fn *GoodStatefulWatermarkEstimatingKv) WatermarkEstimatorState(estimator 
WatermarkEstimatorT) int {
+       return 0
+}
+
 // Examples of incorrect SDF signatures.
 // Examples with missing methods.
 
@@ -972,6 +1037,93 @@ func (fn 
*BadWatermarkEstimatingCreateWatermarkEstimatorReturnType) CreateWaterm
        return 5
 }
 
+type BadStatefulWatermarkEstimatingInconsistentState struct {
+       *GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingInconsistentState) 
WatermarkEstimatorState(estimator WatermarkEstimatorT) string {
+       return ""
+}
+
+type BadStatefulWatermarkEstimatingInconsistentEstimator struct {
+       *GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingInconsistentEstimator) 
WatermarkEstimatorState(estimator WatermarkEstimator2T) int {
+       return 0
+}
+
+type BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateParams struct {
+       *GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateParams) 
WatermarkEstimatorState(estimator WatermarkEstimatorT, element int) int {
+       return 0
+}
+
+type BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoParams struct 
{
+       *GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoParams) 
WatermarkEstimatorState() int {
+       return 0
+}
+
+type BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateReturns struct {
+       *GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateReturns) 
WatermarkEstimatorState(estimator WatermarkEstimatorT) (int, error) {
+       return 0, nil
+}
+
+type BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoReturns 
struct {
+       *GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingExtraWatermarkEstimatorStateNoReturns) 
WatermarkEstimatorState(estimator WatermarkEstimatorT) {
+}
+
+type BadStatefulWatermarkEstimatingWrongPositionalParameter0 struct {
+       *GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingWrongPositionalParameter0) InitialWatermarkEstimatorState(a int, rt RestT, element int) int {
+       return 0
+}
+
+type BadStatefulWatermarkEstimatingWrongPositionalParameter1 struct {
+       *GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingWrongPositionalParameter1) InitialWatermarkEstimatorState(ts typex.EventTime, rt *RTracker2T, element int) int {
+       return 0
+}
+
+type BadStatefulWatermarkEstimatingWrongPositionalParameter2 struct {
+       *GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingWrongPositionalParameter2) InitialWatermarkEstimatorState(ts typex.EventTime, rt RestT, element string) int {
+       return 0
+}
+
+type BadStatefulKvWatermarkEstimatingWrongPositionalParameter2 struct {
+       *GoodStatefulWatermarkEstimatingKv
+}
+
+func (fn *BadStatefulKvWatermarkEstimatingWrongPositionalParameter2) InitialWatermarkEstimatorState(ts typex.EventTime, rt RestT, k string, v int) int {
+       return 0
+}
+
+type BadStatefulWatermarkEstimatingWrongReturn struct {
+       *GoodStatefulWatermarkEstimating
+}
+
+func (fn *BadStatefulWatermarkEstimatingWrongReturn) InitialWatermarkEstimatorState(ts typex.EventTime, rt RestT, element int) string {
+       return ""
+}
+
 // Examples of correct CombineFn signatures
 
 type MyAccum struct{}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go 
b/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go
index 64355a84c1f..bd2835fe2ff 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go
@@ -125,7 +125,7 @@ func TestDynamicSplit(t *testing.T) {
                        if err := procRes; err != nil {
                                t.Fatal(err)
                        }
-                       pRest := 
p.Elm.(*FullValue).Elm2.(offsetrange.Restriction)
+                       pRest := 
p.Elm.(*FullValue).Elm2.(*FullValue).Elm.(offsetrange.Restriction)
                        if got, want := len(out.Elements), 
int(pRest.End-pRest.Start); got != want {
                                t.Errorf("Unexpected number of elements: got: 
%v, want: %v", got, want)
                        }
@@ -226,8 +226,11 @@ func claimBlockingDriver(plan *Plan, dc DataContext, sdf 
*splitTestSdf) (splitRe
 func createElm() *FullValue {
        return &FullValue{
                Elm: &FullValue{
-                       Elm:  20,
-                       Elm2: offsetrange.Restriction{Start: 0, End: 20},
+                       Elm: 20,
+                       Elm2: &FullValue{
+                               Elm:  offsetrange.Restriction{Start: 0, End: 
20},
+                               Elm2: false,
+                       },
                },
                Elm2: float64(20),
        }
@@ -244,7 +247,10 @@ func createSplitTestInCoder() *coder.Coder {
                coder.NewKV([]*coder.Coder{
                        coder.NewKV([]*coder.Coder{
                                intCoder(reflectx.Int),
-                               {Kind: coder.Custom, T: typex.New(restT), 
Custom: restCdr},
+                               coder.NewKV([]*coder.Coder{
+                                       {Kind: coder.Custom, T: 
typex.New(restT), Custom: restCdr},
+                                       coder.NewBool(),
+                               }),
                        }),
                        coder.NewDouble(),
                }),
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf.go 
b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
index cd1234c29e3..ec457213e86 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
@@ -21,6 +21,7 @@ import (
        "math"
        "path"
 
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
@@ -37,7 +38,8 @@ type PairWithRestriction struct {
        Fn  *graph.DoFn
        Out Node
 
-       inv *cirInvoker
+       inv     *cirInvoker
+       iwesInv *iwesInvoker
 }
 
 // ID returns the UnitID for this unit.
@@ -52,6 +54,13 @@ func (n *PairWithRestriction) Up(_ context.Context) error {
        if n.inv, err = newCreateInitialRestrictionInvoker(fn); err != nil {
                return errors.WithContextf(err, "%v", n)
        }
+       var giwesFn *funcx.Fn
+       if (*graph.SplittableDoFn)(n.Fn).IsStatefulWatermarkEstimating() {
+               giwesFn = 
(*graph.SplittableDoFn)(n.Fn).InitialWatermarkEstimatorStateFn()
+       }
+       if n.iwesInv, err = newInitialWatermarkEstimatorStateInvoker(giwesFn); 
err != nil {
+               return errors.WithContextf(err, "%v", n)
+       }
        return nil
 }
 
@@ -73,13 +82,16 @@ func (n *PairWithRestriction) StartBundle(ctx 
context.Context, id string, data D
 //
 //   *FullValue {
 //     Elm: *FullValue (original input)
-//     Elm2: Restriction
+//     Elm2: *FullValue {
+//       Elm: Restriction
+//       Elm2: Watermark estimator state
+//     }
 //     Windows
 //     Timestamps
 //   }
 func (n *PairWithRestriction) ProcessElement(ctx context.Context, elm 
*FullValue, values ...ReStream) error {
        rest := n.inv.Invoke(elm)
-       output := FullValue{Elm: elm, Elm2: rest, Timestamp: elm.Timestamp, 
Windows: elm.Windows}
+       output := FullValue{Elm: elm, Elm2: &FullValue{Elm: rest, Elm2: 
n.iwesInv.Invoke(rest, elm)}, Timestamp: elm.Timestamp, Windows: elm.Windows}
 
        return n.Out.ProcessElement(ctx, &output, values...)
 }
@@ -87,6 +99,7 @@ func (n *PairWithRestriction) ProcessElement(ctx 
context.Context, elm *FullValue
 // FinishBundle resets the invokers.
 func (n *PairWithRestriction) FinishBundle(ctx context.Context) error {
        n.inv.Reset()
+       n.iwesInv.Reset()
        return n.Out.FinishBundle(ctx)
 }
 
@@ -147,13 +160,16 @@ func (n *SplitAndSizeRestrictions) StartBundle(ctx 
context.Context, id string, d
 //
 //   *FullValue {
 //     Elm: *FullValue (original input)
-//     Elm2: Restriction
+//     Elm2: *FullValue {
+//       Elm: Restriction
+//       Elm2: Watermark estimator state
+//     }
 //     Windows
 //     Timestamps
 //   }
 //
 // ProcessElement splits the given restriction into one or more restrictions 
and
-// then sizes each. The outputs are in the structure <<elem, restriction>, 
size>
+// then sizes each. The outputs are in the structure <<elem, <restriction, 
watermark estimator state>>, size>
 // where elem is the original main input to the unexpanded SDF. Windows and
 // Timestamps are copied to each split output.
 //
@@ -162,14 +178,18 @@ func (n *SplitAndSizeRestrictions) StartBundle(ctx 
context.Context, id string, d
 //   *FullValue {
 //     Elm: *FullValue {
 //       Elm:  *FullValue (original input)
-//       Elm2: Restriction
+//       Elm2: *FullValue {
+//         Elm: Restriction
+//         Elm2: Watermark estimator state
+//       }
 //     }
 //     Elm2: float64 (size)
 //     Windows
 //     Timestamps
 //   }
 func (n *SplitAndSizeRestrictions) ProcessElement(ctx context.Context, elm 
*FullValue, values ...ReStream) error {
-       rest := elm.Elm2
+       rest := elm.Elm2.(*FullValue).Elm
+       ws := elm.Elm2.(*FullValue).Elm2
        mainElm := elm.Elm.(*FullValue)
 
        splitRests := n.splitInv.Invoke(mainElm, rest)
@@ -184,7 +204,7 @@ func (n *SplitAndSizeRestrictions) ProcessElement(ctx 
context.Context, elm *Full
 
                output.Timestamp = elm.Timestamp
                output.Windows = elm.Windows
-               output.Elm = &FullValue{Elm: mainElm, Elm2: splitRest}
+               output.Elm = &FullValue{Elm: mainElm, Elm2: &FullValue{Elm: 
splitRest, Elm2: ws}}
                output.Elm2 = size
 
                if err := n.Out.ProcessElement(ctx, output, values...); err != 
nil {
@@ -223,6 +243,7 @@ type ProcessSizedElementsAndRestrictions struct {
        ctInv   *ctInvoker
        sizeInv *rsInvoker
        cweInv  *cweInvoker
+       wesInv  *wesInvoker
 
        // SU is a buffered channel for indicating when this unit is splittable.
        // When this unit is processing an element, it sends a SplittableUnit
@@ -242,9 +263,10 @@ type ProcessSizedElementsAndRestrictions struct {
        // from a DoFn for use in splitting the bundle if the process should be 
resumed.
        continuation sdf.ProcessContinuation
 
-       elm   *FullValue   // Currently processing element.
-       rt    sdf.RTracker // Currently processing element's restriction 
tracker.
-       currW int          // Index of the current window in elm being 
processed.
+       elm     *FullValue   // Currently processing element.
+       rt      sdf.RTracker // Currently processing element's restriction 
tracker.
+       currW   int          // Index of the current window in elm being 
processed.
+       initWeS interface{}  // Initial state of the watermark estimator before 
processing elements.
 
        // Number of windows being processed. This number can differ from the 
number
        // of windows in an element, indicating to only process a subset of 
windows.
@@ -278,6 +300,13 @@ func (n *ProcessSizedElementsAndRestrictions) Up(ctx 
context.Context) error {
                        return errors.WithContextf(err, "%v", n)
                }
        }
+       var gwesFn *funcx.Fn
+       if (*graph.SplittableDoFn)(n.PDo.Fn).IsStatefulWatermarkEstimating() {
+               gwesFn = 
(*graph.SplittableDoFn)(n.PDo.Fn).WatermarkEstimatorStateFn()
+       }
+       if n.wesInv, err = newWatermarkEstimatorStateInvoker(gwesFn); err != 
nil {
+               return errors.WithContextf(err, "%v", n)
+       }
        n.SU = make(chan SplittableUnit, 1)
        return n.PDo.Up(ctx)
 }
@@ -288,7 +317,7 @@ func (n *ProcessSizedElementsAndRestrictions) 
StartBundle(ctx context.Context, i
 }
 
 // ProcessElement expects the same structure as the output of
-// SplitAndSizeRestrictions, approximately <<elem, restriction>, size>. The
+// SplitAndSizeRestrictions, approximately <<elem, <restriction, watermark estimator state>>, size>. The
 // only difference is that if the input was decoded in between the two steps,
 // then single-element inputs were lifted from the *FullValue they were
 // stored in.
@@ -298,7 +327,10 @@ func (n *ProcessSizedElementsAndRestrictions) 
StartBundle(ctx context.Context, i
 //   *FullValue {
 //     Elm: *FullValue {
 //       Elm:  *FullValue (KV input) or InputType (single-element input)
-//       Elm2: Restriction
+//       Elm2: *FullValue {
+//         Elm: Restriction
+//         Elm2: Watermark estimator state
+//       }
 //     }
 //     Elm2: float64 (size)
 //     Windows
@@ -343,15 +375,16 @@ func (n *ProcessSizedElementsAndRestrictions) 
ProcessElement(_ context.Context,
        }
 
        if n.cweInv != nil {
-               n.PDo.we = n.cweInv.Invoke()
+               n.PDo.we = 
n.cweInv.Invoke(elm.Elm.(*FullValue).Elm2.(*FullValue).Elm2)
        }
+       n.initWeS = n.wesInv.Invoke(n.PDo.we)
 
        // Begin processing elements, exploding windows if necessary.
        n.currW = 0
        if !mustExplodeWindows(n.PDo.inv.fn, elm, len(n.PDo.Side) > 0) {
                // If windows don't need to be exploded (i.e. aren't observed), 
treat
                // all windows as one as an optimization.
-               rest := elm.Elm.(*FullValue).Elm2
+               rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
                rt := n.ctInv.Invoke(rest)
                mainIn.RTracker = rt
 
@@ -373,7 +406,7 @@ func (n *ProcessSizedElementsAndRestrictions) 
ProcessElement(_ context.Context,
                n.numW = len(elm.Windows)
 
                for i := 0; i < n.numW; i++ {
-                       rest := elm.Elm.(*FullValue).Elm2
+                       rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
                        rt := n.ctInv.Invoke(rest)
                        key := &mainIn.Key
                        w := elm.Windows[i]
@@ -402,6 +435,7 @@ func (n *ProcessSizedElementsAndRestrictions) 
FinishBundle(ctx context.Context)
        if n.cweInv != nil {
                n.cweInv.Reset()
        }
+       n.wesInv.Reset()
        return n.PDo.FinishBundle(ctx)
 }
 
@@ -457,6 +491,17 @@ type SplittableUnit interface {
 // each case occurs and the implementation details, see the documentation for
 // the singleWindowSplit and multiWindowSplit methods.
 func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, 
[]*FullValue, error) {
+       // Get the watermark state immediately so that we don't overestimate 
our current watermark.
+       var pWeState interface{}
+       var rWeState interface{}
+       rWeState = n.wesInv.Invoke(n.PDo.we)
+       pWeState = rWeState
+       // If we've processed elements, the initial watermark estimator state 
will be set.
+       // In that case the primary keeps that initial state so that we don't
+       // advance past where the current elements are holding the watermark.
+       if n.initWeS != nil {
+               pWeState = n.initWeS
+       }
        addContext := func(err error) error {
                return errors.WithContext(err, "Attempting split in 
ProcessSizedElementsAndRestrictions")
        }
@@ -472,7 +517,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f 
float64) ([]*FullValue, []
        // Split behavior differs depending on whether this is a 
window-observing
        // DoFn or not.
        if len(n.elm.Windows) > 1 {
-               p, r, err := n.multiWindowSplit(f)
+               p, r, err := n.multiWindowSplit(f, pWeState, rWeState)
                if err != nil {
                        return nil, nil, addContext(err)
                }
@@ -480,7 +525,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f 
float64) ([]*FullValue, []
        }
 
        // Not window-observing, or window-observing but only one window.
-       p, r, err := n.singleWindowSplit(f)
+       p, r, err := n.singleWindowSplit(f, pWeState, rWeState)
        if err != nil {
                return nil, nil, addContext(err)
        }
@@ -492,7 +537,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f 
float64) ([]*FullValue, []
 // behavior is identical). A single restriction split will occur and all 
windows
 // present in the unsplit element will be present in both the resulting primary
 // and residual.
-func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64) 
([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, 
pWeState, rWeState interface{}) ([]*FullValue, []*FullValue, error) {
        if n.rt.IsDone() { // Not an error, but not splittable.
                return []*FullValue{}, []*FullValue{}, nil
        }
@@ -505,11 +550,11 @@ func (n *ProcessSizedElementsAndRestrictions) 
singleWindowSplit(f float64) ([]*F
                return []*FullValue{}, []*FullValue{}, nil
        }
 
-       pfv, err := n.newSplitResult(p, n.elm.Windows)
+       pfv, err := n.newSplitResult(p, n.elm.Windows, pWeState)
        if err != nil {
                return nil, nil, err
        }
-       rfv, err := n.newSplitResult(r, n.elm.Windows)
+       rfv, err := n.newSplitResult(r, n.elm.Windows, rWeState)
        if err != nil {
                return nil, nil, err
        }
@@ -540,7 +585,7 @@ func (n *ProcessSizedElementsAndRestrictions) 
singleWindowSplit(f float64) ([]*F
 //
 // This method also updates the current number of windows (n.numW) so that
 // windows in the residual will no longer be processed.
-func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64) 
([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64, 
pWeState interface{}, rWeState interface{}) ([]*FullValue, []*FullValue, error) 
{
        // Get the split point in window range, to see what window it falls in.
        done, rem := n.rt.GetProgress()
        cwp := done / (done + rem)                      // Progress in current 
window.
@@ -553,25 +598,25 @@ func (n *ProcessSizedElementsAndRestrictions) 
multiWindowSplit(f float64) ([]*Fu
                if n.rt.IsDone() {
                        // Current RTracker is done so we can't split within 
the window, so
                        // split at window boundary instead.
-                       return n.windowBoundarySplit(n.currW + 1)
+                       return n.windowBoundarySplit(n.currW+1, pWeState, 
rWeState)
                }
 
                // Get the fraction of remaining work in the current window to 
split at.
                cwsp := wsp - float64(n.currW) // Split point in current window.
                rf := (cwsp - cwp) / (1 - cwp) // Fraction of work in RTracker 
to split at.
 
-               return n.currentWindowSplit(rf)
+               return n.currentWindowSplit(rf, pWeState, rWeState)
        } else {
                // Split at nearest window boundary to split point.
                wb := math.Round(wsp)
-               return n.windowBoundarySplit(int(wb))
+               return n.windowBoundarySplit(int(wb), pWeState, rWeState)
        }
 }
 
 // currentWindowSplit performs an appropriate split at the given fraction of
 // remaining work in the current window. Also updates numW to stop after the
 // current window.
-func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64) 
([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, 
pWeState interface{}, rWeState interface{}) ([]*FullValue, []*FullValue, error) 
{
        p, r, err := n.rt.TrySplit(f)
        if err != nil {
                return nil, nil, err
@@ -579,33 +624,33 @@ func (n *ProcessSizedElementsAndRestrictions) 
currentWindowSplit(f float64) ([]*
        if r == nil {
                // If r is nil then the split failed/returned an empty 
residual, but
                // we can still split at a window boundary.
-               return n.windowBoundarySplit(n.currW + 1)
+               return n.windowBoundarySplit(n.currW+1, pWeState, rWeState)
        }
 
        // Split of currently processing restriction in a single window.
        ps := make([]*FullValue, 1)
-       newP, err := n.newSplitResult(p, n.elm.Windows[n.currW:n.currW+1])
+       newP, err := n.newSplitResult(p, n.elm.Windows[n.currW:n.currW+1], 
pWeState)
        if err != nil {
                return nil, nil, err
        }
        ps[0] = newP
        rs := make([]*FullValue, 1)
-       newR, err := n.newSplitResult(r, n.elm.Windows[n.currW:n.currW+1])
+       newR, err := n.newSplitResult(r, n.elm.Windows[n.currW:n.currW+1], 
rWeState)
        if err != nil {
                return nil, nil, err
        }
        rs[0] = newR
        // Window boundary split surrounding the split restriction above.
-       full := n.elm.Elm.(*FullValue).Elm2
+       full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
        if 0 < n.currW {
-               newP, err := n.newSplitResult(full, n.elm.Windows[0:n.currW])
+               newP, err := n.newSplitResult(full, n.elm.Windows[0:n.currW], 
pWeState)
                if err != nil {
                        return nil, nil, err
                }
                ps = append(ps, newP)
        }
        if n.currW+1 < n.numW {
-               newR, err := n.newSplitResult(full, 
n.elm.Windows[n.currW+1:n.numW])
+               newR, err := n.newSplitResult(full, 
n.elm.Windows[n.currW+1:n.numW], rWeState)
                if err != nil {
                        return nil, nil, err
                }
@@ -618,17 +663,17 @@ func (n *ProcessSizedElementsAndRestrictions) 
currentWindowSplit(f float64) ([]*
 // windowBoundarySplit performs an appropriate split at a window boundary. The
 // split point taken should be the index of the first window in the residual.
 // Also updates numW to stop at the split point.
-func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int) 
([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int, 
pWeState interface{}, rWeState interface{}) ([]*FullValue, []*FullValue, error) 
{
        // If this is at the boundary of the last window, split is a no-op.
        if splitPt == n.numW {
                return []*FullValue{}, []*FullValue{}, nil
        }
-       full := n.elm.Elm.(*FullValue).Elm2
-       pFv, err := n.newSplitResult(full, n.elm.Windows[0:splitPt])
+       full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
+       pFv, err := n.newSplitResult(full, n.elm.Windows[0:splitPt], pWeState)
        if err != nil {
                return nil, nil, err
        }
-       rFv, err := n.newSplitResult(full, n.elm.Windows[splitPt:n.numW])
+       rFv, err := n.newSplitResult(full, n.elm.Windows[splitPt:n.numW], 
rWeState)
        if err != nil {
                return nil, nil, err
        }
@@ -640,7 +685,7 @@ func (n *ProcessSizedElementsAndRestrictions) 
windowBoundarySplit(splitPt int) (
 // element restriction pair based on the currently processing element, but with
 // a modified restriction and windows. Intended for creating primaries and
 // residuals to return as split results.
-func (n *ProcessSizedElementsAndRestrictions) newSplitResult(rest interface{}, 
w []typex.Window) (*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) newSplitResult(rest interface{}, 
w []typex.Window, weState interface{}) (*FullValue, error) {
        var size float64
        elm := n.elm.Elm.(*FullValue).Elm
        if fv, ok := elm.(*FullValue); ok {
@@ -659,8 +704,11 @@ func (n *ProcessSizedElementsAndRestrictions) 
newSplitResult(rest interface{}, w
        }
        return &FullValue{
                Elm: &FullValue{
-                       Elm:  elm,
-                       Elm2: rest,
+                       Elm: elm,
+                       Elm2: &FullValue{
+                               Elm:  rest,
+                               Elm2: weState,
+                       },
                },
                Elm2:      size,
                Timestamp: n.elm.Timestamp,
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go 
b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
index 79eeda23afc..f90cdd0abf5 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
@@ -305,7 +305,7 @@ func (n *ctInvoker) Reset() {
 type cweInvoker struct {
        fn   *funcx.Fn
        args []interface{} // Cache to avoid allocating new slices per-element.
-       call func() sdf.WatermarkEstimator
+       call func(rest interface{}) sdf.WatermarkEstimator
 }
 
 func newCreateWatermarkEstimatorInvoker(fn *funcx.Fn) (*cweInvoker, error) {
@@ -321,27 +321,38 @@ func newCreateWatermarkEstimatorInvoker(fn *funcx.Fn) 
(*cweInvoker, error) {
 
 func (n *cweInvoker) initCallFn() error {
        // Expects a signature of the form:
-       // () sdf.WatermarkEstimator
+       // (watermarkState?) sdf.WatermarkEstimator
        switch fnT := n.fn.Fn.(type) {
        case reflectx.Func0x1:
-               n.call = func() sdf.WatermarkEstimator {
+               n.call = func(rest interface{}) sdf.WatermarkEstimator {
                        return fnT.Call0x1().(sdf.WatermarkEstimator)
                }
+       case reflectx.Func1x1:
+               n.call = func(rest interface{}) sdf.WatermarkEstimator {
+                       return fnT.Call1x1(rest).(sdf.WatermarkEstimator)
+               }
        default:
-               if len(n.fn.Param) != 0 {
+               switch len(n.fn.Param) {
+               case 0:
+                       n.call = func(rest interface{}) sdf.WatermarkEstimator {
+                               return 
n.fn.Fn.Call(n.args)[0].(sdf.WatermarkEstimator)
+                       }
+               case 1:
+                       n.call = func(rest interface{}) sdf.WatermarkEstimator {
+                               n.args[0] = rest
+                               return 
n.fn.Fn.Call(n.args)[0].(sdf.WatermarkEstimator)
+                       }
+               default:
                        return errors.Errorf("CreateWatermarkEstimator fn %v 
has unexpected number of parameters: %v",
                                n.fn.Fn.Name(), len(n.fn.Param))
                }
-               n.call = func() sdf.WatermarkEstimator {
-                       return n.fn.Fn.Call(n.args)[0].(sdf.WatermarkEstimator)
-               }
        }
        return nil
 }
 
 // Invoke calls CreateWatermarkEstimator given a restriction and returns an 
sdf.WatermarkEstimator.
-func (n *cweInvoker) Invoke() sdf.WatermarkEstimator {
-       return n.call()
+func (n *cweInvoker) Invoke(rest interface{}) sdf.WatermarkEstimator {
+       return n.call(rest)
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -351,3 +362,147 @@ func (n *cweInvoker) Reset() {
                n.args[i] = nil
        }
 }
+
+// iwesInvoker is an invoker for InitialWatermarkEstimatorState.
+type iwesInvoker struct {
+       fn   *funcx.Fn
+       args []interface{} // Cache to avoid allocating new slices per-element.
+       call func(rest interface{}, elms *FullValue) interface{}
+}
+
+func newInitialWatermarkEstimatorStateInvoker(fn *funcx.Fn) (*iwesInvoker, 
error) {
+       args := []interface{}{}
+       if fn != nil {
+               args = make([]interface{}, len(fn.Param))
+       }
+       n := &iwesInvoker{
+               fn:   fn,
+               args: args,
+       }
+       if err := n.initCallFn(); err != nil {
+               return nil, errors.WithContext(err, "sdf 
InitialWatermarkEstimatorState invoker")
+       }
+       return n, nil
+}
+
+func (n *iwesInvoker) initCallFn() error {
+       // If no InitialWatermarkEstimatorState function is defined, we'll use a default implementation that just returns false as the state.
+       if n.fn == nil {
+               n.call = func(rest interface{}, elms *FullValue) interface{} {
+                       return false
+               }
+               return nil
+       }
+       // Expects a signature of the form:
+       // (typex.EventTime, restriction, key?, value) interface{}
+       switch fnT := n.fn.Fn.(type) {
+       case reflectx.Func3x1:
+               n.call = func(rest interface{}, elms *FullValue) interface{} {
+                       return fnT.Call3x1(elms.Timestamp, rest, elms.Elm)
+               }
+       case reflectx.Func4x1:
+               n.call = func(rest interface{}, elms *FullValue) interface{} {
+                       return fnT.Call4x1(elms.Timestamp, rest, elms.Elm, 
elms.Elm2)
+               }
+       default:
+               switch len(n.fn.Param) {
+               case 3:
+                       n.call = func(rest interface{}, elms *FullValue) 
interface{} {
+                               n.args[0] = elms.Timestamp
+                               n.args[1] = rest
+                               n.args[2] = elms.Elm
+                               return n.fn.Fn.Call(n.args)[0]
+                       }
+               case 4:
+                       n.call = func(rest interface{}, elms *FullValue) 
interface{} {
+                               n.args[0] = elms.Timestamp
+                               n.args[1] = rest
+                               n.args[2] = elms.Elm
+                               n.args[3] = elms.Elm2
+                               return n.fn.Fn.Call(n.args)[0]
+                       }
+               default:
+                       return errors.Errorf("InitialWatermarkEstimatorState fn 
%v has unexpected number of parameters: %v",
+                               n.fn.Fn.Name(), len(n.fn.Param))
+               }
+       }
+       return nil
+}
+
+// Invoke calls InitialWatermarkEstimatorState given a restriction and an element and returns the initial watermark estimator state.
+func (n *iwesInvoker) Invoke(rest interface{}, elms *FullValue) interface{} {
+       return n.call(rest, elms)
+}
+
+// Reset zeroes argument entries in the cached slice to allow values to be
+// garbage collected after the bundle ends.
+func (n *iwesInvoker) Reset() {
+       for i := range n.args {
+               n.args[i] = nil
+       }
+}
+
+// wesInvoker is an invoker for WatermarkEstimatorState.
+type wesInvoker struct {
+       fn   *funcx.Fn
+       args []interface{} // Cache to avoid allocating new slices per-element.
+       call func(we sdf.WatermarkEstimator) interface{}
+}
+
+func newWatermarkEstimatorStateInvoker(fn *funcx.Fn) (*wesInvoker, error) {
+       args := []interface{}{}
+       if fn != nil {
+               args = make([]interface{}, len(fn.Param))
+       }
+       n := &wesInvoker{
+               fn:   fn,
+               args: args,
+       }
+       if err := n.initCallFn(); err != nil {
+               return nil, errors.WithContext(err, "sdf 
WatermarkEstimatorState invoker")
+       }
+       return n, nil
+}
+
+func (n *wesInvoker) initCallFn() error {
+       // If no WatermarkEstimatorState function is defined, we'll use a 
default implementation that just returns false as the state.
+       if n.fn == nil {
+               n.call = func(we sdf.WatermarkEstimator) interface{} {
+                       return false
+               }
+               return nil
+       }
+       // Expects a signature of the form:
+       // (sdf.WatermarkEstimator) state
+       switch fnT := n.fn.Fn.(type) {
+       case reflectx.Func1x1:
+               n.call = func(we sdf.WatermarkEstimator) interface{} {
+                       return fnT.Call1x1(we)
+               }
+       default:
+               switch len(n.fn.Param) {
+               case 1:
+                       n.call = func(we sdf.WatermarkEstimator) interface{} {
+                               n.args[0] = we
+                               return n.fn.Fn.Call(n.args)[0]
+                       }
+               default:
+                       return errors.Errorf("WatermarkEstimatorState fn %v has 
unexpected number of parameters: %v",
+                               n.fn.Fn.Name(), len(n.fn.Param))
+               }
+       }
+       return nil
+}
+
+// Invoke calls WatermarkEstimatorState given a watermark estimator and returns the estimator's current state.
+func (n *wesInvoker) Invoke(we sdf.WatermarkEstimator) interface{} {
+       return n.call(we)
+}
+
+// Reset zeroes argument entries in the cached slice to allow values to be
+// garbage collected after the bundle ends.
+func (n *wesInvoker) Reset() {
+       for i := range n.args {
+               n.args[i] = nil
+       }
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go 
b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
index 10f5f899384..bf959d8ee01 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
@@ -20,6 +20,8 @@ import (
        "time"
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
        "github.com/google/go-cmp/cmp"
 )
 
@@ -39,6 +41,12 @@ func TestInvokes(t *testing.T) {
        }
        kvsdf := (*graph.SplittableDoFn)(dfn)
 
+       dfn, err = graph.NewDoFn(&VetSdfStatefulWatermark{}, 
graph.NumMainInputs(graph.MainSingle))
+       if err != nil {
+               t.Fatalf("invalid function: %v", err)
+       }
+       statefulWeFn := (*graph.SplittableDoFn)(dfn)
+
        // Tests.
        t.Run("CreateInitialRestriction Invoker (cirInvoker)", func(t 
*testing.T) {
                tests := []struct {
@@ -231,21 +239,71 @@ func TestInvokes(t *testing.T) {
        })
 
        t.Run("CreateWatermarkEstimator Invoker (cweInvoker)", func(t 
*testing.T) {
-               fn := sdf.CreateWatermarkEstimatorFn()
-               invoker, err := newCreateWatermarkEstimatorInvoker(fn)
+               tests := []struct {
+                       name  string
+                       sdf   *graph.SplittableDoFn
+                       state int
+                       want  VetWatermarkEstimator
+               }{
+                       {
+                               name:  "Non-stateful",
+                               sdf:   sdf,
+                               state: 1,
+                               want:  VetWatermarkEstimator{State: -1},
+                       }, {
+                               name:  "Stateful",
+                               sdf:   statefulWeFn,
+                               state: 11,
+                               want:  VetWatermarkEstimator{State: 11},
+                       },
+               }
+
+               for _, test := range tests {
+                       test := test
+                       fn := test.sdf.CreateWatermarkEstimatorFn()
+                       t.Run(test.name, func(t *testing.T) {
+                               invoker, err := 
newCreateWatermarkEstimatorInvoker(fn)
+                               if err != nil {
+                                       
t.Fatalf("newCreateWatermarkEstimatorInvoker failed: %v", err)
+                               }
+                               got := invoker.Invoke(test.state)
+                               want := &test.want
+                               if !cmp.Equal(got, want) {
+                                       t.Errorf("Invoke() has incorrect 
output: got: %v, want: %v", got, want)
+                               }
+                               invoker.Reset()
+                               for i, arg := range invoker.args {
+                                       if arg != nil {
+                                               t.Errorf("Reset() failed to 
empty all args. args[%v] = %v", i, arg)
+                                       }
+                               }
+                       })
+               }
+       })
+
+       t.Run("InitialWatermarkEstimatorState Invoker (iwesInvoker)", func(t 
*testing.T) {
+               fn := statefulWeFn.InitialWatermarkEstimatorStateFn()
+               invoker, err := newInitialWatermarkEstimatorStateInvoker(fn)
                if err != nil {
-                       t.Fatalf("newCreateWatermarkEstimatorInvoker failed: 
%v", err)
+                       t.Fatalf("newInitialWatermarkEstimatorStateInvoker 
failed: %v", err)
                }
-               got := invoker.Invoke()
-               want := &VetWatermarkEstimator{}
-               if !cmp.Equal(got, want) {
+               got := invoker.Invoke(&VetRestriction{ID: "Sdf"}, 
&FullValue{Elm: 1, Timestamp: mtime.ZeroTimestamp})
+               want := 1
+               if got != want {
                        t.Errorf("Invoke() has incorrect output: got: %v, want: 
%v", got, want)
                }
-               invoker.Reset()
-               for i, arg := range invoker.args {
-                       if arg != nil {
-                               t.Errorf("Reset() failed to empty all args. 
args[%v] = %v", i, arg)
-                       }
+       })
+
+       t.Run("WatermarkEstimatorState Invoker (wesInvoker)", func(t 
*testing.T) {
+               fn := statefulWeFn.WatermarkEstimatorStateFn()
+               invoker, err := newWatermarkEstimatorStateInvoker(fn)
+               if err != nil {
+                       t.Fatalf("newWatermarkEstimatorStateInvoker failed: 
%v", err)
+               }
+               got := invoker.Invoke(&VetWatermarkEstimator{State: 11})
+               want := 11
+               if got != want {
+                       t.Errorf("Invoke() has incorrect output: got: %v, want: 
%v", got, want)
                }
        })
 }
@@ -288,7 +346,9 @@ func (rt *VetRTracker) TrySplit(_ float64) (interface{}, 
interface{}, error) {
        return nil, nil, nil
 }
 
-type VetWatermarkEstimator struct{}
+type VetWatermarkEstimator struct {
+       State int
+}
 
 func (e *VetWatermarkEstimator) CurrentWatermark() time.Time {
        return time.Date(2022, time.January, 1, 1, 0, 0, 0, time.UTC)
@@ -340,7 +400,7 @@ func (fn *VetSdf) CreateTracker(rest *VetRestriction) 
*VetRTracker {
 
 // CreateWatermarkEstimator creates a watermark estimator to be used by the Sdf
 func (fn *VetSdf) CreateWatermarkEstimator() *VetWatermarkEstimator {
-       return &VetWatermarkEstimator{}
+       return &VetWatermarkEstimator{State: -1}
 }
 
 // ProcessElement emits the restriction from the restriction tracker it
@@ -356,6 +416,57 @@ func (fn *VetSdf) ProcessElement(rt *VetRTracker, i int, 
emit func(*VetRestricti
        emit(rest)
 }
 
+type VetSdfStatefulWatermark struct {
+}
+
+func (fn *VetSdfStatefulWatermark) CreateInitialRestriction(i int) 
*VetRestriction {
+       return &VetRestriction{ID: "Sdf", Val: i, CreateRest: true}
+}
+
+func (fn *VetSdfStatefulWatermark) SplitRestriction(i int, rest 
*VetRestriction) []*VetRestriction {
+       rest.SplitRest = true
+       rest.Val = i
+
+       rest1 := rest.copy()
+       rest1.ID += ".1"
+       rest2 := rest.copy()
+       rest2.ID += ".2"
+
+       return []*VetRestriction{&rest1, &rest2}
+}
+
+func (fn *VetSdfStatefulWatermark) RestrictionSize(i int, rest 
*VetRestriction) float64 {
+       rest.Key = nil
+       rest.Val = i
+       rest.RestSize = true
+       return (float64)(i)
+}
+
+func (fn *VetSdfStatefulWatermark) CreateTracker(rest *VetRestriction) 
*VetRTracker {
+       rest.CreateTracker = true
+       return &VetRTracker{rest}
+}
+
+func (fn *VetSdfStatefulWatermark) InitialWatermarkEstimatorState(_ 
typex.EventTime, _ *VetRestriction, element int) int {
+       return 1
+}
+
+func (fn *VetSdfStatefulWatermark) CreateWatermarkEstimator(state int) 
*VetWatermarkEstimator {
+       return &VetWatermarkEstimator{State: state}
+}
+
+func (fn *VetSdfStatefulWatermark) WatermarkEstimatorState(e 
*VetWatermarkEstimator) int {
+       return e.State
+}
+
+func (fn *VetSdfStatefulWatermark) ProcessElement(rt *VetRTracker, i int, emit 
func(*VetRestriction)) {
+       rest := rt.Rest
+       rest.Key = nil
+       rest.Val = i
+       rest.ProcessElm = true
+       emit(rest)
+}
+
 // VetKvSdf runs an SDF In order to test that these methods get called 
properly,
 // each method will flip the corresponding flag in the passed in 
VetRestriction,
 // overwrite the restriction's Key and Val with the last seen input elements,
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go 
b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
index e5729bfb031..1f9c56bdbac 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
@@ -64,6 +64,10 @@ func TestSdfNodes(t *testing.T) {
        if err != nil {
                t.Fatalf("invalid function: %v", err)
        }
+       statefulWeFn, err := graph.NewDoFn(&VetSdfStatefulWatermark{}, 
graph.NumMainInputs(graph.MainSingle))
+       if err != nil {
+               t.Fatalf("invalid function: %v", err)
+       }
 
        // Validate PairWithRestriction matches its contract and properly 
invokes
        // SDF method CreateInitialRestriction.
@@ -90,7 +94,34 @@ func TestSdfNodes(t *testing.T) {
                                                Timestamp: testTimestamp,
                                                Windows:   testWindows,
                                        },
-                                       Elm2:      &VetRestriction{ID: "Sdf", 
CreateRest: true, Val: 1},
+                                       Elm2: &FullValue{
+                                               Elm:  &VetRestriction{ID: 
"Sdf", CreateRest: true, Val: 1},
+                                               Elm2: false,
+                                       },
+                                       Timestamp: testTimestamp,
+                                       Windows:   testWindows,
+                               },
+                       },
+                       {
+                               name: "SingleElemStatefulWatermarkEstimating",
+                               fn:   statefulWeFn,
+                               in: FullValue{
+                                       Elm:       1,
+                                       Elm2:      nil,
+                                       Timestamp: testTimestamp,
+                                       Windows:   testWindows,
+                               },
+                               want: FullValue{
+                                       Elm: &FullValue{
+                                               Elm:       1,
+                                               Elm2:      nil,
+                                               Timestamp: testTimestamp,
+                                               Windows:   testWindows,
+                                       },
+                                       Elm2: &FullValue{
+                                               Elm:  &VetRestriction{ID: 
"Sdf", CreateRest: true, Val: 1},
+                                               Elm2: 1,
+                                       },
                                        Timestamp: testTimestamp,
                                        Windows:   testWindows,
                                },
@@ -111,7 +142,10 @@ func TestSdfNodes(t *testing.T) {
                                                Timestamp: testTimestamp,
                                                Windows:   testWindows,
                                        },
-                                       Elm2:      &VetRestriction{ID: "KvSdf", 
CreateRest: true, Key: 1, Val: 2},
+                                       Elm2: &FullValue{
+                                               Elm:  &VetRestriction{ID: 
"KvSdf", CreateRest: true, Key: 1, Val: 2},
+                                               Elm2: false,
+                                       },
                                        Timestamp: testTimestamp,
                                        Windows:   testWindows,
                                },
@@ -154,7 +188,10 @@ func TestSdfNodes(t *testing.T) {
                                                Timestamp: testTimestamp,
                                                Windows:   testWindows,
                                        },
-                                       Elm2:      &VetRestriction{ID: "Sdf"},
+                                       Elm2: &FullValue{
+                                               Elm:  &VetRestriction{ID: 
"Sdf"},
+                                               Elm2: 1,
+                                       },
                                        Timestamp: testTimestamp,
                                        Windows:   testWindows,
                                },
@@ -167,7 +204,10 @@ func TestSdfNodes(t *testing.T) {
                                                                Timestamp: 
testTimestamp,
                                                                Windows:   
testWindows,
                                                        },
-                                                       Elm2: 
&VetRestriction{ID: "Sdf.1", SplitRest: true, RestSize: true, Val: 1},
+                                                       Elm2: &FullValue{
+                                                               Elm:  
&VetRestriction{ID: "Sdf.1", SplitRest: true, RestSize: true, Val: 1},
+                                                               Elm2: 1,
+                                                       },
                                                },
                                                Elm2:      1.0,
                                                Timestamp: testTimestamp,
@@ -181,7 +221,10 @@ func TestSdfNodes(t *testing.T) {
                                                                Timestamp: 
testTimestamp,
                                                                Windows:   
testWindows,
                                                        },
-                                                       Elm2: 
&VetRestriction{ID: "Sdf.2", SplitRest: true, RestSize: true, Val: 1},
+                                                       Elm2: &FullValue{
+                                                               Elm:  
&VetRestriction{ID: "Sdf.2", SplitRest: true, RestSize: true, Val: 1},
+                                                               Elm2: 1,
+                                                       },
                                                },
                                                Elm2:      1.0,
                                                Timestamp: testTimestamp,
@@ -199,7 +242,10 @@ func TestSdfNodes(t *testing.T) {
                                                Timestamp: testTimestamp,
                                                Windows:   testWindows,
                                        },
-                                       Elm2:      &VetRestriction{ID: "KvSdf"},
+                                       Elm2: &FullValue{
+                                               Elm:  &VetRestriction{ID: 
"KvSdf"},
+                                               Elm2: false,
+                                       },
                                        Timestamp: testTimestamp,
                                        Windows:   testWindows,
                                },
@@ -212,7 +258,10 @@ func TestSdfNodes(t *testing.T) {
                                                                Timestamp: 
testTimestamp,
                                                                Windows:   
testWindows,
                                                        },
-                                                       Elm2: 
&VetRestriction{ID: "KvSdf.1", SplitRest: true, RestSize: true, Key: 1, Val: 2},
+                                                       Elm2: &FullValue{
+                                                               Elm:  
&VetRestriction{ID: "KvSdf.1", SplitRest: true, RestSize: true, Key: 1, Val: 2},
+                                                               Elm2: false,
+                                                       },
                                                },
                                                Elm2:      3.0,
                                                Timestamp: testTimestamp,
@@ -226,7 +275,10 @@ func TestSdfNodes(t *testing.T) {
                                                                Timestamp: 
testTimestamp,
                                                                Windows:   
testWindows,
                                                        },
-                                                       Elm2: 
&VetRestriction{ID: "KvSdf.2", SplitRest: true, RestSize: true, Key: 1, Val: 2},
+                                                       Elm2: &FullValue{
+                                                               Elm:  
&VetRestriction{ID: "KvSdf.2", SplitRest: true, RestSize: true, Key: 1, Val: 2},
+                                                               Elm2: false,
+                                                       },
                                                },
                                                Elm2:      3.0,
                                                Timestamp: testTimestamp,
@@ -244,7 +296,10 @@ func TestSdfNodes(t *testing.T) {
                                                Timestamp: testTimestamp,
                                                Windows:   testWindows,
                                        },
-                                       Elm2:      &VetRestriction{ID: "Sdf"},
+                                       Elm2: &FullValue{
+                                               Elm:  &VetRestriction{ID: 
"Sdf"},
+                                               Elm2: false,
+                                       },
                                        Timestamp: testTimestamp,
                                        Windows:   testWindows,
                                },
@@ -296,7 +351,10 @@ func TestSdfNodes(t *testing.T) {
                                                Timestamp: testTimestamp,
                                                Windows:   testWindows,
                                        },
-                                       Elm2:      
offsetrange.Restriction{Start: 0, End: 4},
+                                       Elm2: &FullValue{
+                                               Elm:  
offsetrange.Restriction{Start: 0, End: 4},
+                                               Elm2: false,
+                                       },
                                        Timestamp: testTimestamp,
                                        Windows:   testWindows,
                                },
@@ -338,8 +396,33 @@ func TestSdfNodes(t *testing.T) {
                                fn:   dfn,
                                in: FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf"},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: false,
+                                               },
+                                       },
+                                       Elm2:      1.0,
+                                       Timestamp: testTimestamp,
+                                       Windows:   testWindows,
+                               },
+                               want: FullValue{
+                                       Elm:       &VetRestriction{ID: "Sdf", 
CreateTracker: true, ProcessElm: true, Val: 1},
+                                       Elm2:      nil,
+                                       Timestamp: testTimestamp,
+                                       Windows:   testWindows,
+                               },
+                       },
+                       {
+                               name: "SingleElemStatefulWatermarkEstimating",
+                               fn:   statefulWeFn,
+                               in: FullValue{
+                                       Elm: &FullValue{
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: 1,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -363,7 +446,10 @@ func TestSdfNodes(t *testing.T) {
                                                        Timestamp: 
testTimestamp,
                                                        Windows:   testWindows,
                                                },
-                                               Elm2: &VetRestriction{ID: 
"KvSdf"},
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "KvSdf"},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      3.0,
                                        Timestamp: testTimestamp,
@@ -495,6 +581,10 @@ func TestAsSplittableUnit(t *testing.T) {
        if err != nil {
                t.Fatalf("invalid function: %v", err)
        }
+       statefulWeFn, err := graph.NewDoFn(&VetSdfStatefulWatermark{}, 
graph.NumMainInputs(graph.MainSingle))
+       if err != nil {
+               t.Fatalf("invalid function: %v", err)
+       }
        multiWindows := []typex.Window{
                window.IntervalWindow{Start: 10, End: 20},
                window.IntervalWindow{Start: 11, End: 21},
@@ -537,8 +627,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                // but the element is still built to be valid.
                                elm := FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf"},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -549,7 +642,7 @@ func TestAsSplittableUnit(t *testing.T) {
                                n := &ParDo{UID: 1, Fn: dfn, Out: []Node{}}
                                node := 
&ProcessSizedElementsAndRestrictions{PDo: n}
                                node.rt = &SplittableUnitRTracker{
-                                       VetRTracker: VetRTracker{Rest: 
elm.Elm.(*FullValue).Elm2.(*VetRestriction)},
+                                       VetRTracker: VetRTracker{Rest: 
elm.Elm.(*FullValue).Elm2.(*FullValue).Elm.(*VetRestriction)},
                                        Done:        test.doneWork,
                                        Remaining:   test.remainingWork,
                                        ThisIsDone:  false,
@@ -589,8 +682,52 @@ func TestAsSplittableUnit(t *testing.T) {
                                frac: 0.5,
                                in: FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf"},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: false,
+                                               },
+                                       },
+                                       Elm2:      1.0,
+                                       Timestamp: testTimestamp,
+                                       Windows:   testWindows,
+                               },
+                               wantPrimaries: []*FullValue{{
+                                       Elm: &FullValue{
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf.1", RestSize: true, Val: 1},
+                                                       Elm2: false,
+                                               },
+                                       },
+                                       Elm2:      1.0,
+                                       Timestamp: testTimestamp,
+                                       Windows:   testWindows,
+                               }},
+                               wantResiduals: []*FullValue{{
+                                       Elm: &FullValue{
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf.2", RestSize: true, Val: 1},
+                                                       Elm2: false,
+                                               },
+                                       },
+                                       Elm2:      1.0,
+                                       Timestamp: testTimestamp,
+                                       Windows:   testWindows,
+                               }},
+                       },
+                       {
+                               name: "SingleElemStatefulWatermarkEstimating",
+                               fn:   statefulWeFn,
+                               frac: 0.5,
+                               in: FullValue{
+                                       Elm: &FullValue{
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: 0,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -598,8 +735,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                },
                                wantPrimaries: []*FullValue{{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf.1", RestSize: true, Val: 1},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf.1", RestSize: true, Val: 1},
+                                                       Elm2: 1,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -607,8 +747,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                }},
                                wantResiduals: []*FullValue{{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf.2", RestSize: true, Val: 1},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf.2", RestSize: true, Val: 1},
+                                                       Elm2: 1,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -625,7 +768,10 @@ func TestAsSplittableUnit(t *testing.T) {
                                                        Elm:  1,
                                                        Elm2: 2,
                                                },
-                                               Elm2: &VetRestriction{ID: 
"KvSdf"},
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "KvSdf"},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      3.0,
                                        Timestamp: testTimestamp,
@@ -637,7 +783,10 @@ func TestAsSplittableUnit(t *testing.T) {
                                                        Elm:  1,
                                                        Elm2: 2,
                                                },
-                                               Elm2: &VetRestriction{ID: 
"KvSdf.1", RestSize: true, Key: 1, Val: 2},
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "KvSdf.1", RestSize: true, Key: 1, Val: 2},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      3.0,
                                        Timestamp: testTimestamp,
@@ -649,7 +798,10 @@ func TestAsSplittableUnit(t *testing.T) {
                                                        Elm:  1,
                                                        Elm2: 2,
                                                },
-                                               Elm2: &VetRestriction{ID: 
"KvSdf.2", RestSize: true, Key: 1, Val: 2},
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "KvSdf.2", RestSize: true, Key: 1, Val: 2},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      3.0,
                                        Timestamp: testTimestamp,
@@ -663,8 +815,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                frac:   0.5,
                                in: FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf"},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -681,8 +836,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                frac: 0.125, // Should be in the middle of the 
first (current) window.
                                in: FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf"},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -690,8 +848,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                },
                                wantPrimaries: []*FullValue{{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf.1", RestSize: true, Val: 1},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf.1", RestSize: true, Val: 1},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -699,16 +860,22 @@ func TestAsSplittableUnit(t *testing.T) {
                                }},
                                wantResiduals: []*FullValue{{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf.2", RestSize: true, Val: 1},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf.2", RestSize: true, Val: 1},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
                                        Windows:   testMultiWindows[0:1],
                                }, {
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf", RestSize: true, Val: 1},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -723,8 +890,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                frac: 0.55,
                                in: FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf"},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -732,8 +902,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                },
                                wantPrimaries: []*FullValue{{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf", RestSize: true, Val: 1},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -741,8 +914,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                }},
                                wantResiduals: []*FullValue{{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf", RestSize: true, Val: 1},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -758,8 +934,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                doneRt: true,
                                in: FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf"},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -767,8 +946,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                },
                                wantPrimaries: []*FullValue{{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf", RestSize: true, Val: 1},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -776,8 +958,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                }},
                                wantResiduals: []*FullValue{{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf", RestSize: true, Val: 1},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf", RestSize: true, Val: 1},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -792,8 +977,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                frac: 0.95, // Should round to end of element 
and cause a no-op.
                                in: FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf"},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -807,10 +995,10 @@ func TestAsSplittableUnit(t *testing.T) {
                        test := test
                        t.Run(test.name, func(t *testing.T) {
                                // Setup, create transforms, inputs, and 
desired outputs.
-                               n := &ParDo{UID: 1, Fn: test.fn, Out: []Node{}}
+                               n := &ParDo{UID: 1, Fn: test.fn, Out: []Node{}, 
we: &VetWatermarkEstimator{State: 1}}
                                node := 
&ProcessSizedElementsAndRestrictions{PDo: n}
                                node.rt = &SplittableUnitRTracker{
-                                       VetRTracker: VetRTracker{Rest: 
test.in.Elm.(*FullValue).Elm2.(*VetRestriction)},
+                                       VetRTracker: VetRTracker{Rest: 
test.in.Elm.(*FullValue).Elm2.(*FullValue).Elm.(*VetRestriction)},
                                        Done:        0,
                                        Remaining:   1.0,
                                        ThisIsDone:  test.doneRt,
@@ -850,8 +1038,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                fn:   pdfn,
                                in: FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: 
&offsetrange.Restriction{Start: 0, End: 4},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&offsetrange.Restriction{Start: 0, End: 4},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -863,8 +1054,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                fn:   rdfn,
                                in: FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: 
&offsetrange.Restriction{Start: 0, End: 4},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&offsetrange.Restriction{Start: 0, End: 4},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -878,7 +1072,7 @@ func TestAsSplittableUnit(t *testing.T) {
                                // Setup, create transforms, inputs, and 
desired outputs.
                                n := &ParDo{UID: 1, Fn: test.fn, Out: []Node{}}
                                node := 
&ProcessSizedElementsAndRestrictions{PDo: n}
-                               node.rt = 
sdf.RTracker(offsetrange.NewTracker(*test.in.Elm.(*FullValue).Elm2.(*offsetrange.Restriction)))
+                               node.rt = 
sdf.RTracker(offsetrange.NewTracker(*test.in.Elm.(*FullValue).Elm2.(*FullValue).Elm.(*offsetrange.Restriction)))
                                node.elm = &test.in
                                node.numW = len(test.in.Windows)
                                node.currW = 0
@@ -911,8 +1105,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                fn:   dfn,
                                in: FullValue{
                                        Elm: &FullValue{
-                                               Elm:  1,
-                                               Elm2: &VetRestriction{ID: 
"Sdf"},
+                                               Elm: 1,
+                                               Elm2: &FullValue{
+                                                       Elm:  
&VetRestriction{ID: "Sdf"},
+                                                       Elm2: false,
+                                               },
                                        },
                                        Elm2:      1.0,
                                        Timestamp: testTimestamp,
@@ -979,8 +1176,11 @@ func TestMultiWindowProcessing(t *testing.T) {
        // Create a plan with a single valid element as input to ProcessElement.
        in := FullValue{
                Elm: &FullValue{
-                       Elm:  1,
-                       Elm2: offsetrange.Restriction{Start: 0, End: 4},
+                       Elm: 1,
+                       Elm2: &FullValue{
+                               Elm:  offsetrange.Restriction{Start: 0, End: 4},
+                               Elm2: false,
+                       },
                },
                Elm2:      4.0,
                Timestamp: testTimestamp,
diff --git a/sdks/go/pkg/beam/core/runtime/genx/genx.go 
b/sdks/go/pkg/beam/core/runtime/genx/genx.go
index 422ee2900c7..b5bf766c268 100644
--- a/sdks/go/pkg/beam/core/runtime/genx/genx.go
+++ b/sdks/go/pkg/beam/core/runtime/genx/genx.go
@@ -117,6 +117,12 @@ func handleDoFn(fn *graph.DoFn, c cache) {
        }
        c.pullMethod(sdf.CreateWatermarkEstimatorFn())
        c.regType(sdf.WatermarkEstimatorT())
+       if !sdf.IsStatefulWatermarkEstimating() {
+               return
+       }
+       c.pullMethod(sdf.InitialWatermarkEstimatorStateFn())
+       c.pullMethod(sdf.WatermarkEstimatorStateFn())
+       c.regType(sdf.WatermarkEstimatorStateT())
 }
 
 func handleCombineFn(fn *graph.CombineFn, c cache) {
diff --git a/sdks/go/pkg/beam/core/runtime/genx/genx_test.go 
b/sdks/go/pkg/beam/core/runtime/genx/genx_test.go
index cc219d6f295..24f96a5bc7a 100644
--- a/sdks/go/pkg/beam/core/runtime/genx/genx_test.go
+++ b/sdks/go/pkg/beam/core/runtime/genx/genx_test.go
@@ -22,6 +22,7 @@ import (
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
        "github.com/google/go-cmp/cmp"
        "github.com/google/go-cmp/cmp/cmpopts"
@@ -41,6 +42,7 @@ func TestRegisterDoFn(t *testing.T) {
        tO := reflect.TypeOf((*O)(nil)).Elem()
        tRt := reflect.TypeOf((*sdf.LockRTracker)(nil)).Elem()
        tWe := reflect.TypeOf((*sdf.WallTimeWatermarkEstimator)(nil)).Elem()
+       tWes := reflect.TypeOf((*WatermarkEstimatorState)(nil)).Elem()
 
        tests := []struct {
                name   string
@@ -69,7 +71,7 @@ func TestRegisterDoFn(t *testing.T) {
                {"DoFn01 pointer reflect", reflect.TypeOf(&DoFn01{}), true, 
false, []reflect.Type{tDoFn01, tR, tS}},
                {"DoFn02 reflect - filtered types", tDoFn02, true, false, 
[]reflect.Type{tDoFn02}},
                {"CombineFn01 reflect - combine methods", tCmbFn01, true, 
false, []reflect.Type{tCmbFn01, tA, tI, tO}},
-               {"DoFn03 reflect - sdf methods", tDoFn03, true, false, 
[]reflect.Type{tDoFn03, tRt, tWe, tR}},
+               {"DoFn03 reflect - sdf methods", tDoFn03, true, false, 
[]reflect.Type{typex.EventTimeType, tDoFn03, tRt, tWe, tWes, tR}},
                {"DoFn04 reflect - containers", tDoFn04, true, false, 
[]reflect.Type{tDoFn04, tR, tS, tT, tA, tI, tO}},
        }
 
@@ -225,10 +227,20 @@ func (fn *DoFn03) CreateTracker(rest R) *sdf.LockRTracker 
{
        return &sdf.LockRTracker{Rt: RT{}}
 }
 
-func (fn *DoFn03) CreateWatermarkEstimator() *sdf.WallTimeWatermarkEstimator {
+type WatermarkEstimatorState struct{}
+
+func (fn *DoFn03) WatermarkEstimatorState(estimator 
*sdf.WallTimeWatermarkEstimator) WatermarkEstimatorState {
+       return WatermarkEstimatorState{}
+}
+
+func (fn *DoFn03) CreateWatermarkEstimator(state WatermarkEstimatorState) 
*sdf.WallTimeWatermarkEstimator {
        return &sdf.WallTimeWatermarkEstimator{}
 }
 
+func (fn *DoFn03) InitialWatermarkEstimatorState(ts typex.EventTime, rest R, s 
string) WatermarkEstimatorState {
+       return WatermarkEstimatorState{}
+}
+
 type DoFn04 struct{}
 
 func (*DoFn04) ProcessElement([4]R, map[S]T, func(*O) bool, func() func(*I) 
bool, func([]A)) {
diff --git a/sdks/go/pkg/beam/pardo.go b/sdks/go/pkg/beam/pardo.go
index 45176febf31..aad86b6a02e 100644
--- a/sdks/go/pkg/beam/pardo.go
+++ b/sdks/go/pkg/beam/pardo.go
@@ -17,6 +17,8 @@ package beam
 
 import (
        "fmt"
+       "reflect"
+
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
@@ -64,9 +66,16 @@ func TryParDo(s Scope, dofn interface{}, col PCollection, 
opts ...Option) ([]PCo
        }
 
        var rc *coder.Coder
+       // Sdfs will always encode restrictions as KV<restriction, watermark 
state | bool(false)>
        if fn.IsSplittable() {
                sdf := (*graph.SplittableDoFn)(fn)
-               rc, err = inferCoder(typex.New(sdf.RestrictionT()))
+               restT := typex.New(sdf.RestrictionT())
+               // If no watermark estimator state, use boolean as a placeholder
+               weT := typex.New(reflect.TypeOf(true))
+               if sdf.IsStatefulWatermarkEstimating() {
+                       weT = typex.New(sdf.WatermarkEstimatorStateT())
+               }
+               rc, err = inferCoder(typex.NewKV(restT, weT))
                if err != nil {
                        return nil, addParDoCtx(err, s)
                }

Reply via email to