This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e1e12453ba [#32929] Add OrderedListState support to Prism. (#33350)
8e1e12453ba is described below

commit 8e1e12453baa8ffa564922496994446c9e41003c
Author: Robert Burke <[email protected]>
AuthorDate: Tue Dec 17 10:45:33 2024 -0800

    [#32929] Add OrderedListState support to Prism. (#33350)
---
 CHANGES.md                                         |   1 +
 runners/prism/java/build.gradle                    |   4 -
 .../pkg/beam/runners/prism/internal/engine/data.go |  97 +++++++++
 .../runners/prism/internal/engine/data_test.go     | 222 +++++++++++++++++++++
 .../prism/internal/engine/elementmanager.go        |  11 +-
 sdks/go/pkg/beam/runners/prism/internal/execute.go |   2 +-
 .../prism/internal/jobservices/management.go       |   3 +-
 sdks/go/pkg/beam/runners/prism/internal/stage.go   |  37 ++++
 .../pkg/beam/runners/prism/internal/urns/urns.go   |   5 +-
 .../beam/runners/prism/internal/worker/worker.go   |  14 ++
 10 files changed, 385 insertions(+), 11 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 7a8ed493c21..deaa8bfcd47 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -74,6 +74,7 @@
 * X feature added (Java/Python) 
([#X](https://github.com/apache/beam/issues/X)).
 * Support OnWindowExpiration in Prism 
([#32211](https://github.com/apache/beam/issues/32211)).
   * This enables initial Java GroupIntoBatches support.
+* Support OrderedListState in Prism 
([#32929](https://github.com/apache/beam/issues/32929)).
 
 ## Breaking Changes
 
diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle
index ce71151099b..cd2e90fde67 100644
--- a/runners/prism/java/build.gradle
+++ b/runners/prism/java/build.gradle
@@ -233,10 +233,6 @@ def createPrismValidatesRunnerTask = { name, 
environmentType ->
       excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService'
       excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment'
 
-      // Not yet implemented in Prism
-      // https://github.com/apache/beam/issues/32929
-      excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState'
-
       // Not supported in Portable Java SDK yet.
       // 
https://github.com/apache/beam/issues?q=is%3Aissue+is%3Aopen+MultimapState
       excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState'
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go
index 7b8689f9511..380b6e2f31d 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go
@@ -17,13 +17,17 @@ package engine
 
 import (
        "bytes"
+       "cmp"
        "fmt"
        "log/slog"
+       "slices"
+       "sort"
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+       "google.golang.org/protobuf/encoding/protowire"
 )
 
 // StateData is a "union" between Bag state and MultiMap state to increase 
common code.
@@ -42,6 +46,10 @@ type TimerKey struct {
 type TentativeData struct {
        Raw map[string][][]byte
 
+       // stateTypeLen maps a state's LinkID to a function that returns the
+       // encoded byte length of the next value in a buffer.
+       // Only used by OrderedListState, since Prism must split and sort those
+       // values itself, which is neither expected nor required for other state types.
+       stateTypeLen map[LinkID]func([]byte) int
        // state is a map from transformID + UserStateID, to window, to 
userKey, to datavalues.
        state map[LinkID]map[typex.Window]map[string]StateData
        // timers is a map from the Timer transform+family to the encoded timer.
@@ -220,3 +228,92 @@ func (d *TentativeData) ClearMultimapKeysState(stateID 
LinkID, wKey, uKey []byte
        kmap[string(uKey)] = StateData{}
        slog.Debug("State() MultimapKeys.Clear", slog.Any("StateID", stateID), 
slog.Any("UserKey", uKey), slog.Any("WindowKey", wKey))
 }
+
+// AppendOrderedListState appends the incoming timestamped data to the 
existing tentative data bundle.
+// Assumes the data is TimestampedValue encoded, which has a BigEndian int64 
suffixed to the data.
+// This means we may always use the last 8 bytes to determine the value 
sorting.
+//
+// The stateID has the Transform and Local fields populated, for the Transform 
and UserStateID respectively.
+func (d *TentativeData) AppendOrderedListState(stateID LinkID, wKey, uKey 
[]byte, data []byte) {
+       kmap := d.appendState(stateID, wKey)
+       typeLen := d.stateTypeLen[stateID]
+       var datums [][]byte
+
+       // We need to parse out all values individually for later sorting.
+       //
+       // OrderedListState is encoded as KVs with varint encoded millis 
followed by the value.
+       // This is not the standard TimestampValueCoder encoding, which
+       // uses a big-endian long as a suffix to the value. This is important 
since
+       // values may be concatenated, and we'll need to split them out out.
+       //
+       // The TentativeData.stateTypeLen is populated with a function to 
extract
+       // the length of a the next value, so we can skip through elements 
individually.
+       for i := 0; i < len(data); {
+               // Get the length of the VarInt for the timestamp.
+               _, tn := protowire.ConsumeVarint(data[i:])
+
+               // Get the length of the encoded value.
+               vn := typeLen(data[i+tn:])
+               prev := i
+               i += tn + vn
+               datums = append(datums, data[prev:i])
+       }
+
+       s := StateData{Bag: append(kmap[string(uKey)].Bag, datums...)}
+       sort.SliceStable(s.Bag, func(i, j int) bool {
+               vi := s.Bag[i]
+               vj := s.Bag[j]
+               return compareTimestampSuffixes(vi, vj)
+       })
+       kmap[string(uKey)] = s
+       slog.Debug("State() OrderedList.Append", slog.Any("StateID", stateID), 
slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Any("NewData", s))
+}
+
+func compareTimestampSuffixes(vi, vj []byte) bool {
+       ims, _ := protowire.ConsumeVarint(vi)
+       jms, _ := protowire.ConsumeVarint(vj)
+       return (int64(ims)) < (int64(jms))
+}
+
+// GetOrderedListState available state from the tentative bundle data.
+// The stateID has the Transform and Local fields populated, for the Transform 
and UserStateID respectively.
+func (d *TentativeData) GetOrderedListState(stateID LinkID, wKey, uKey []byte, 
start, end int64) [][]byte {
+       winMap := d.state[stateID]
+       w := d.toWindow(wKey)
+       data := winMap[w][string(uKey)]
+
+       lo, hi := findRange(data.Bag, start, end)
+       slog.Debug("State() OrderedList.Get", slog.Any("StateID", stateID), 
slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", 
slog.Int64("start", start), slog.Int64("end", end)), slog.Group("outrange", 
slog.Int("lo", lo), slog.Int("hi", hi)), slog.Any("Data", data.Bag[lo:hi]))
+       return data.Bag[lo:hi]
+}
+
+func cmpSuffix(vs [][]byte, target int64) func(i int) int {
+       return func(i int) int {
+               v := vs[i]
+               ims, _ := protowire.ConsumeVarint(v)
+               tvsbi := cmp.Compare(target, int64(ims))
+               slog.Debug("cmpSuffix", "target", target, "bi", ims, "tvsbi", 
tvsbi)
+               return tvsbi
+       }
+}
+
+func findRange(bag [][]byte, start, end int64) (int, int) {
+       lo, _ := sort.Find(len(bag), cmpSuffix(bag, start))
+       hi, _ := sort.Find(len(bag), cmpSuffix(bag, end))
+       return lo, hi
+}
+
+func (d *TentativeData) ClearOrderedListState(stateID LinkID, wKey, uKey 
[]byte, start, end int64) {
+       winMap := d.state[stateID]
+       w := d.toWindow(wKey)
+       kMap := winMap[w]
+       data := kMap[string(uKey)]
+
+       lo, hi := findRange(data.Bag, start, end)
+       slog.Debug("State() OrderedList.Clear", slog.Any("StateID", stateID), 
slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", 
slog.Int64("start", start), slog.Int64("end", end)), "lo", lo, "hi", hi, 
slog.Any("PreClearData", data.Bag))
+
+       cleared := slices.Delete(data.Bag, lo, hi)
+       // Zero the current entry to clear.
+       // Delete makes it difficult to delete the persisted stage state for 
the key.
+       kMap[string(uKey)] = StateData{Bag: cleared}
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go
new file mode 100644
index 00000000000..1d049710418
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go
@@ -0,0 +1,222 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package engine
+
+import (
+       "bytes"
+       "encoding/binary"
+       "math"
+       "testing"
+
+       "github.com/google/go-cmp/cmp"
+       "google.golang.org/protobuf/encoding/protowire"
+)
+
+func TestCompareTimestampSuffixes(t *testing.T) {
+       t.Run("simple", func(t *testing.T) {
+               loI := int64(math.MinInt64)
+               hiI := int64(math.MaxInt64)
+
+               loB := binary.BigEndian.AppendUint64(nil, uint64(loI))
+               hiB := binary.BigEndian.AppendUint64(nil, uint64(hiI))
+
+               if compareTimestampSuffixes(loB, hiB) != (loI < hiI) {
+                       t.Errorf("lo vs Hi%v < %v: bytes %v vs %v, %v %v", loI, 
hiI, loB, hiB, loI < hiI, compareTimestampSuffixes(loB, hiB))
+               }
+       })
+}
+
+func TestOrderedListState(t *testing.T) {
+       time1 := protowire.AppendVarint(nil, 11)
+       time2 := protowire.AppendVarint(nil, 22)
+       time3 := protowire.AppendVarint(nil, 33)
+       time4 := protowire.AppendVarint(nil, 44)
+       time5 := protowire.AppendVarint(nil, 55)
+
+       wKey := []byte{} // global window.
+       uKey := []byte("\u0007userkey")
+       linkID := LinkID{
+               Transform: "dofn",
+               Local:     "localStateName",
+       }
+       cc := func(a []byte, b ...byte) []byte {
+               return bytes.Join([][]byte{a, b}, []byte{})
+       }
+
+       t.Run("bool", func(t *testing.T) {
+               d := TentativeData{
+                       stateTypeLen: map[LinkID]func([]byte) int{
+                               linkID: func(_ []byte) int {
+                                       return 1
+                               },
+                       },
+               }
+
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 1))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 1))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 1))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0))
+
+               got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+               want := [][]byte{
+                       cc(time1, 1),
+                       cc(time2, 0),
+                       cc(time3, 1),
+                       cc(time4, 0),
+                       cc(time5, 1),
+               }
+               if d := cmp.Diff(want, got); d != "" {
+                       t.Errorf("OrderedList booleans \n%v", d)
+               }
+
+               d.ClearOrderedListState(linkID, wKey, uKey, 12, 54)
+               got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+               want = [][]byte{
+                       cc(time1, 1),
+                       cc(time5, 1),
+               }
+               if d := cmp.Diff(want, got); d != "" {
+                       t.Errorf("OrderedList booleans, after clear\n%v", d)
+               }
+       })
+       t.Run("float64", func(t *testing.T) {
+               d := TentativeData{
+                       stateTypeLen: map[LinkID]func([]byte) int{
+                               linkID: func(_ []byte) int {
+                                       return 8
+                               },
+                       },
+               }
+
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 0, 0, 0, 
0, 0, 0, 0, 1))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 0, 0, 0, 
0, 0, 0, 1, 0))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 0, 0, 0, 
0, 0, 1, 0, 0))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0, 0, 0, 
0, 1, 0, 0, 0))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0, 0, 0, 
1, 0, 0, 0, 0))
+
+               got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+               want := [][]byte{
+                       cc(time1, 0, 0, 0, 0, 0, 0, 1, 0),
+                       cc(time2, 0, 0, 0, 0, 1, 0, 0, 0),
+                       cc(time3, 0, 0, 0, 0, 0, 1, 0, 0),
+                       cc(time4, 0, 0, 0, 1, 0, 0, 0, 0),
+                       cc(time5, 0, 0, 0, 0, 0, 0, 0, 1),
+               }
+               if d := cmp.Diff(want, got); d != "" {
+                       t.Errorf("OrderedList float64s \n%v", d)
+               }
+
+               d.ClearOrderedListState(linkID, wKey, uKey, 11, 12)
+               d.ClearOrderedListState(linkID, wKey, uKey, 33, 34)
+               d.ClearOrderedListState(linkID, wKey, uKey, 55, 56)
+
+               got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+               want = [][]byte{
+                       cc(time2, 0, 0, 0, 0, 1, 0, 0, 0),
+                       cc(time4, 0, 0, 0, 1, 0, 0, 0, 0),
+               }
+               if d := cmp.Diff(want, got); d != "" {
+                       t.Errorf("OrderedList float64s, after clear \n%v", d)
+               }
+       })
+
+       t.Run("varint", func(t *testing.T) {
+               d := TentativeData{
+                       stateTypeLen: map[LinkID]func([]byte) int{
+                               linkID: func(b []byte) int {
+                                       _, n := protowire.ConsumeVarint(b)
+                                       return int(n)
+                               },
+                       },
+               }
+
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 
protowire.AppendVarint(nil, 56)...))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 
protowire.AppendVarint(nil, 20067)...))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 
protowire.AppendVarint(nil, 7777777)...))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 
protowire.AppendVarint(nil, 424242)...))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 
protowire.AppendVarint(nil, 0)...))
+
+               got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+               want := [][]byte{
+                       cc(time1, protowire.AppendVarint(nil, 424242)...),
+                       cc(time2, protowire.AppendVarint(nil, 56)...),
+                       cc(time3, protowire.AppendVarint(nil, 7777777)...),
+                       cc(time4, protowire.AppendVarint(nil, 20067)...),
+                       cc(time5, protowire.AppendVarint(nil, 0)...),
+               }
+               if d := cmp.Diff(want, got); d != "" {
+                       t.Errorf("OrderedList int32 \n%v", d)
+               }
+       })
+       t.Run("lp", func(t *testing.T) {
+               d := TentativeData{
+                       stateTypeLen: map[LinkID]func([]byte) int{
+                               linkID: func(b []byte) int {
+                                       l, n := protowire.ConsumeVarint(b)
+                                       return int(l) + n
+                               },
+                       },
+               }
+
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 
[]byte("\u0003one")...))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 
[]byte("\u0003two")...))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 
[]byte("\u0005three")...))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 
[]byte("\u0004four")...))
+               d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 
[]byte("\u0019FourHundredAndEleventyTwo")...))
+
+               got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+               want := [][]byte{
+                       cc(time1, []byte("\u0003one")...),
+                       cc(time2, []byte("\u0003two")...),
+                       cc(time3, []byte("\u0005three")...),
+                       cc(time4, []byte("\u0004four")...),
+                       cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...),
+               }
+               if d := cmp.Diff(want, got); d != "" {
+                       t.Errorf("OrderedList int32 \n%v", d)
+               }
+       })
+       t.Run("lp_onecall", func(t *testing.T) {
+               d := TentativeData{
+                       stateTypeLen: map[LinkID]func([]byte) int{
+                               linkID: func(b []byte) int {
+                                       l, n := protowire.ConsumeVarint(b)
+                                       return int(l) + n
+                               },
+                       },
+               }
+               d.AppendOrderedListState(linkID, wKey, uKey, 
bytes.Join([][]byte{
+                       time5, []byte("\u0019FourHundredAndEleventyTwo"),
+                       time3, []byte("\u0005three"),
+                       time2, []byte("\u0003two"),
+                       time1, []byte("\u0003one"),
+                       time4, []byte("\u0004four"),
+               }, nil))
+
+               got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+               want := [][]byte{
+                       cc(time1, []byte("\u0003one")...),
+                       cc(time2, []byte("\u0003two")...),
+                       cc(time3, []byte("\u0005three")...),
+                       cc(time4, []byte("\u0004four")...),
+                       cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...),
+               }
+               if d := cmp.Diff(want, got); d != "" {
+                       t.Errorf("OrderedList int32 \n%v", d)
+               }
+       })
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
index 00e18c669af..7180bb456f1 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
@@ -269,8 +269,10 @@ func (em *ElementManager) StageAggregates(ID string) {
 
 // StageStateful marks the given stage as stateful, which means elements are
 // processed by key.
-func (em *ElementManager) StageStateful(ID string) {
-       em.stages[ID].stateful = true
+func (em *ElementManager) StageStateful(ID string, stateTypeLen 
map[LinkID]func([]byte) int) {
+       ss := em.stages[ID]
+       ss.stateful = true
+       ss.stateTypeLen = stateTypeLen
 }
 
 // StageOnWindowExpiration marks the given stage as stateful, which means 
elements are
@@ -669,7 +671,9 @@ func (em *ElementManager) StateForBundle(rb RunBundle) 
TentativeData {
        ss := em.stages[rb.StageID]
        ss.mu.Lock()
        defer ss.mu.Unlock()
-       var ret TentativeData
+       ret := TentativeData{
+               stateTypeLen: ss.stateTypeLen,
+       }
        keys := ss.inprogressKeysByBundle[rb.BundleID]
        // TODO(lostluck): Also track windows per bundle, to reduce copying.
        if len(ss.state) > 0 {
@@ -1136,6 +1140,7 @@ type stageState struct {
        inprogressKeys         set[string]                                      
// all keys that are assigned to bundles.
        inprogressKeysByBundle map[string]set[string]                           
// bundle to key assignments.
        state                  map[LinkID]map[typex.Window]map[string]StateData 
// state data for this stage, from {tid, stateID} -> window -> userKey
+       stateTypeLen           map[LinkID]func([]byte) int                      
// map from state to a function that will produce the total length of a single 
value in bytes.
 
        // Accounting for handling watermark holds for timers.
        // We track the count of timers with the same hold, and clear it from
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go 
b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index fde62f00c7c..8b56c30eb61 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -316,7 +316,7 @@ func executePipeline(ctx context.Context, wks 
map[string]*worker.W, j *jobservic
                        sort.Strings(outputs)
                        em.AddStage(stage.ID, []string{stage.primaryInput}, 
outputs, stage.sideInputs)
                        if stage.stateful {
-                               em.StageStateful(stage.ID)
+                               em.StageStateful(stage.ID, stage.stateTypeLen)
                        }
                        if stage.onWindowExpiration.TimerFamily != "" {
                                slog.Debug("OnWindowExpiration", 
slog.String("stage", stage.ID), slog.Any("values", stage.onWindowExpiration))
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go 
b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index 894a6e1427a..af559a92ab4 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -174,7 +174,8 @@ func (s *Server) Prepare(ctx context.Context, req 
*jobpb.PrepareJobRequest) (_ *
                        // Validate all the state features
                        for _, spec := range pardo.GetStateSpecs() {
                                isStateful = true
-                               check("StateSpec.Protocol.Urn", 
spec.GetProtocol().GetUrn(), urns.UserStateBag, urns.UserStateMultiMap)
+                               check("StateSpec.Protocol.Urn", 
spec.GetProtocol().GetUrn(),
+                                       urns.UserStateBag, 
urns.UserStateMultiMap, urns.UserStateOrderedList)
                        }
                        // Validate all the timer features
                        for _, spec := range pardo.GetTimerFamilySpecs() {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go 
b/sdks/go/pkg/beam/runners/prism/internal/stage.go
index 9dd6cbdafec..e1e942a06f0 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/stage.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go
@@ -35,6 +35,7 @@ import (
        
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
        "golang.org/x/exp/maps"
        "google.golang.org/protobuf/encoding/prototext"
+       "google.golang.org/protobuf/encoding/protowire"
        "google.golang.org/protobuf/proto"
 )
 
@@ -73,6 +74,10 @@ type stage struct {
        hasTimers            []engine.StaticTimerID
        processingTimeTimers map[string]bool
 
+       // stateTypeLen maps state values to encoded lengths for the type.
+       // Only used for OrderedListState which must manipulate individual 
state datavalues.
+       stateTypeLen map[engine.LinkID]func([]byte) int
+
        exe              transformExecuter
        inputTransformID string
        inputInfo        engine.PColInfo
@@ -438,6 +443,38 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, 
wk *worker.W, em *eng
                                rewriteCoder(&s.SetSpec.ElementCoderId)
                        case *pipepb.StateSpec_OrderedListSpec:
                                rewriteCoder(&s.OrderedListSpec.ElementCoderId)
+                               // Add the length determination helper for 
OrderedList state values.
+                               if stg.stateTypeLen == nil {
+                                       stg.stateTypeLen = 
map[engine.LinkID]func([]byte) int{}
+                               }
+                               linkID := engine.LinkID{
+                                       Transform: tid,
+                                       Local:     stateID,
+                               }
+                               var fn func([]byte) int
+                               switch v := 
coders[s.OrderedListSpec.GetElementCoderId()]; v.GetSpec().GetUrn() {
+                               case urns.CoderBool:
+                                       fn = func(_ []byte) int {
+                                               return 1
+                                       }
+                               case urns.CoderDouble:
+                                       fn = func(_ []byte) int {
+                                               return 8
+                                       }
+                               case urns.CoderVarInt:
+                                       fn = func(b []byte) int {
+                                               _, n := 
protowire.ConsumeVarint(b)
+                                               return int(n)
+                                       }
+                               case urns.CoderLengthPrefix, urns.CoderBytes, 
urns.CoderStringUTF8:
+                                       fn = func(b []byte) int {
+                                               l, n := 
protowire.ConsumeVarint(b)
+                                               return int(l) + n
+                                       }
+                               default:
+                                       rewriteErr = fmt.Errorf("unknown coder 
used for ordered list state after re-write id: %v coder: %v, for state %v for 
transform %v in stage %v", s.OrderedListSpec.GetElementCoderId(), v, stateID, 
tid, stg.ID)
+                               }
+                               stg.stateTypeLen[linkID] = fn
                        case *pipepb.StateSpec_CombiningSpec:
                                
rewriteCoder(&s.CombiningSpec.AccumulatorCoderId)
                        case *pipepb.StateSpec_MapSpec:
diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go 
b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
index 5312fd799c8..12e62ef84a8 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
@@ -95,8 +95,9 @@ var (
        SideInputMultiMap = siUrn(pipepb.StandardSideInputTypes_MULTIMAP)
 
        // UserState kinds
-       UserStateBag      = usUrn(pipepb.StandardUserStateTypes_BAG)
-       UserStateMultiMap = usUrn(pipepb.StandardUserStateTypes_MULTIMAP)
+       UserStateBag         = usUrn(pipepb.StandardUserStateTypes_BAG)
+       UserStateMultiMap    = usUrn(pipepb.StandardUserStateTypes_MULTIMAP)
+       UserStateOrderedList = usUrn(pipepb.StandardUserStateTypes_ORDERED_LIST)
 
        // WindowsFns
        WindowFnGlobal  = quickUrn(pipepb.GlobalWindowsPayload_PROPERTIES)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go 
b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
index c2c988aa097..9d9058975b2 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -554,6 +554,11 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) 
error {
                                case *fnpb.StateKey_MultimapKeysUserState_:
                                        mmkey := key.GetMultimapKeysUserState()
                                        data = 
b.OutputData.GetMultimapKeysState(engine.LinkID{Transform: 
mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), 
mmkey.GetKey())
+                               case *fnpb.StateKey_OrderedListUserState_:
+                                       olkey := key.GetOrderedListUserState()
+                                       data = b.OutputData.GetOrderedListState(
+                                               engine.LinkID{Transform: 
olkey.GetTransformId(), Local: olkey.GetUserStateId()},
+                                               olkey.GetWindow(), 
olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd())
                                default:
                                        panic(fmt.Sprintf("unsupported StateKey 
Get type: %T: %v", key.GetType(), prototext.Format(key)))
                                }
@@ -578,6 +583,11 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) 
error {
                                case *fnpb.StateKey_MultimapUserState_:
                                        mmkey := key.GetMultimapUserState()
                                        
b.OutputData.AppendMultimapState(engine.LinkID{Transform: 
mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), 
mmkey.GetKey(), mmkey.GetMapKey(), req.GetAppend().GetData())
+                               case *fnpb.StateKey_OrderedListUserState_:
+                                       olkey := key.GetOrderedListUserState()
+                                       b.OutputData.AppendOrderedListState(
+                                               engine.LinkID{Transform: 
olkey.GetTransformId(), Local: olkey.GetUserStateId()},
+                                               olkey.GetWindow(), 
olkey.GetKey(), req.GetAppend().GetData())
                                default:
                                        panic(fmt.Sprintf("unsupported StateKey 
Append type: %T: %v", key.GetType(), prototext.Format(key)))
                                }
@@ -601,6 +611,10 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) 
error {
                                case *fnpb.StateKey_MultimapKeysUserState_:
                                        mmkey := key.GetMultimapUserState()
                                        
b.OutputData.ClearMultimapKeysState(engine.LinkID{Transform: 
mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), 
mmkey.GetKey())
+                               case *fnpb.StateKey_OrderedListUserState_:
+                                       olkey := key.GetOrderedListUserState()
+                                       
b.OutputData.ClearOrderedListState(engine.LinkID{Transform: 
olkey.GetTransformId(), Local: olkey.GetUserStateId()},
+                                               olkey.GetWindow(), 
olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd())
                                default:
                                        panic(fmt.Sprintf("unsupported StateKey 
Clear type: %T: %v", key.GetType(), prototext.Format(key)))
                                }

Reply via email to