This is an automated email from the ASF dual-hosted git repository.
lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 8e1e12453ba [#32929] Add OrderedListState support to Prism. (#33350)
8e1e12453ba is described below
commit 8e1e12453baa8ffa564922496994446c9e41003c
Author: Robert Burke <[email protected]>
AuthorDate: Tue Dec 17 10:45:33 2024 -0800
[#32929] Add OrderedListState support to Prism. (#33350)
---
CHANGES.md | 1 +
runners/prism/java/build.gradle | 4 -
.../pkg/beam/runners/prism/internal/engine/data.go | 97 +++++++++
.../runners/prism/internal/engine/data_test.go | 222 +++++++++++++++++++++
.../prism/internal/engine/elementmanager.go | 11 +-
sdks/go/pkg/beam/runners/prism/internal/execute.go | 2 +-
.../prism/internal/jobservices/management.go | 3 +-
sdks/go/pkg/beam/runners/prism/internal/stage.go | 37 ++++
.../pkg/beam/runners/prism/internal/urns/urns.go | 5 +-
.../beam/runners/prism/internal/worker/worker.go | 14 ++
10 files changed, 385 insertions(+), 11 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 7a8ed493c21..deaa8bfcd47 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -74,6 +74,7 @@
* X feature added (Java/Python)
([#X](https://github.com/apache/beam/issues/X)).
* Support OnWindowExpiration in Prism
([#32211](https://github.com/apache/beam/issues/32211)).
* This enables initial Java GroupIntoBatches support.
+* Support OrderedListState in Prism
([#32929](https://github.com/apache/beam/issues/32929)).
## Breaking Changes
diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle
index ce71151099b..cd2e90fde67 100644
--- a/runners/prism/java/build.gradle
+++ b/runners/prism/java/build.gradle
@@ -233,10 +233,6 @@ def createPrismValidatesRunnerTask = { name,
environmentType ->
excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService'
excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment'
- // Not yet implemented in Prism
- // https://github.com/apache/beam/issues/32929
- excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState'
-
// Not supported in Portable Java SDK yet.
//
https://github.com/apache/beam/issues?q=is%3Aissue+is%3Aopen+MultimapState
excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState'
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go
b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go
index 7b8689f9511..380b6e2f31d 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go
@@ -17,13 +17,17 @@ package engine
import (
"bytes"
+ "cmp"
"fmt"
"log/slog"
+ "slices"
+ "sort"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+ "google.golang.org/protobuf/encoding/protowire"
)
// StateData is a "union" between Bag state and MultiMap state to increase
common code.
@@ -42,6 +46,10 @@ type TimerKey struct {
type TentativeData struct {
Raw map[string][][]byte
+ // stateTypeLen is a map from LinkID to a function that returns the encoded
+ // length of the next value, used for parsing concatenated data.
+ // Only used by OrderedListState, since Prism must split and sort these
+ // datavalues itself, which is neither expected of nor required for other
+ // state values.
+ stateTypeLen map[LinkID]func([]byte) int
// state is a map from transformID + UserStateID, to window, to
userKey, to datavalues.
state map[LinkID]map[typex.Window]map[string]StateData
// timers is a map from the Timer transform+family to the encoded timer.
@@ -220,3 +228,92 @@ func (d *TentativeData) ClearMultimapKeysState(stateID
LinkID, wKey, uKey []byte
kmap[string(uKey)] = StateData{}
slog.Debug("State() MultimapKeys.Clear", slog.Any("StateID", stateID),
slog.Any("UserKey", uKey), slog.Any("WindowKey", wKey))
}
+
+// AppendOrderedListState appends the incoming timestamped data to the existing tentative data bundle.
+//
+// The data is one or more concatenated entries, each encoded as a varint
+// millisecond timestamp followed by the encoded value. This is not the
+// standard TimestampedValue encoding, which instead suffixes the value with a
+// big-endian int64. The entries are split apart using the length function
+// registered for stateID in stateTypeLen, merged into the stored bag, and the
+// bag is kept sorted by timestamp. Entries with equal timestamps keep their
+// arrival order.
+//
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+//
+// Panics if no length function is registered for stateID, or if data can't be
+// split into whole entries.
+func (d *TentativeData) AppendOrderedListState(stateID LinkID, wKey, uKey []byte, data []byte) {
+	typeLen := d.stateTypeLen[stateID]
+	if typeLen == nil {
+		// Fail with a descriptive message instead of an opaque nil function call.
+		panic(fmt.Sprintf("AppendOrderedListState: no value length function registered for state %v", stateID))
+	}
+	kmap := d.appendState(stateID, wKey)
+
+	// We need to parse out all values individually for later sorting, since
+	// several entries may be concatenated in a single append.
+	var datums [][]byte
+	for i := 0; i < len(data); {
+		// Get the length of the VarInt for the timestamp.
+		// ConsumeVarint reports a negative length for truncated or overflowing input.
+		_, tn := protowire.ConsumeVarint(data[i:])
+		if tn < 0 {
+			panic(fmt.Sprintf("AppendOrderedListState: malformed timestamp at offset %d for state %v: %v", i, stateID, protowire.ParseError(tn)))
+		}
+
+		// Get the length of the encoded value.
+		vn := typeLen(data[i+tn:])
+		next := i + tn + vn
+		// A negative or overlong value length would either stall this loop
+		// or slice past the end of data.
+		if vn < 0 || next > len(data) {
+			panic(fmt.Sprintf("AppendOrderedListState: malformed value at offset %d for state %v: value length %d with %d bytes remaining", i+tn, stateID, vn, len(data)-i-tn))
+		}
+		datums = append(datums, data[i:next])
+		i = next
+	}
+
+	s := StateData{Bag: append(kmap[string(uKey)].Bag, datums...)}
+	// Stable, so entries sharing a timestamp stay in arrival order.
+	sort.SliceStable(s.Bag, func(i, j int) bool {
+		return compareTimestampSuffixes(s.Bag[i], s.Bag[j])
+	})
+	kmap[string(uKey)] = s
+	slog.Debug("State() OrderedList.Append", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Any("NewData", s))
+}
+
+// compareTimestampSuffixes reports whether the timestamp of vi sorts strictly
+// before the timestamp of vj.
+//
+// Despite the name, the timestamp is read from the varint at the start of
+// each entry, and is compared as a signed 64 bit integer. Decode errors are
+// ignored, in which case the timestamp is treated as 0.
+func compareTimestampSuffixes(vi, vj []byte) bool {
+	left, _ := protowire.ConsumeVarint(vi)
+	right, _ := protowire.ConsumeVarint(vj)
+	return cmp.Compare(int64(left), int64(right)) < 0
+}
+
+// GetOrderedListState returns the entries of the ordered list state in the
+// tentative bundle data whose timestamps are at or after start, and before end.
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+func (d *TentativeData) GetOrderedListState(stateID LinkID, wKey, uKey []byte, start, end int64) [][]byte {
+	// Indexing missing entries yields zero values, so absent state is an empty bag.
+	bag := d.state[stateID][d.toWindow(wKey)][string(uKey)].Bag
+
+	lo, hi := findRange(bag, start, end)
+	selected := bag[lo:hi]
+	slog.Debug("State() OrderedList.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", slog.Int64("start", start), slog.Int64("end", end)), slog.Group("outrange", slog.Int("lo", lo), slog.Int("hi", hi)), slog.Any("Data", selected))
+	return selected
+}
+
+// cmpSuffix returns a comparison function suitable for sort.Find, which
+// compares target against the leading varint timestamp of vs[i].
+// The result is negative, zero, or positive when target is respectively
+// less than, equal to, or greater than that timestamp.
+func cmpSuffix(vs [][]byte, target int64) func(i int) int {
+	return func(i int) int {
+		ts, _ := protowire.ConsumeVarint(vs[i])
+		order := cmp.Compare(target, int64(ts))
+		slog.Debug("cmpSuffix", "target", target, "bi", ts, "tvsbi", order)
+		return order
+	}
+}
+
+// findRange returns the index range [lo, hi) of the entries in bag whose
+// leading varint timestamps are at or after start, and before end.
+// The bag must already be sorted ascending by timestamp.
+//
+// The returned hi is never less than lo, so the result is always safe to
+// use as slice bounds on bag.
+func findRange(bag [][]byte, start, end int64) (int, int) {
+	lo, _ := sort.Find(len(bag), cmpSuffix(bag, start))
+	hi, _ := sort.Find(len(bag), cmpSuffix(bag, end))
+	// An inverted range (end < start) would otherwise yield hi < lo, and
+	// slicing bag[lo:hi] with those bounds panics. Treat it as empty instead.
+	return lo, max(lo, hi)
+}
+
+// ClearOrderedListState removes the entries of the ordered list state in the
+// tentative bundle data whose timestamps are at or after start, and before end.
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+func (d *TentativeData) ClearOrderedListState(stateID LinkID, wKey, uKey []byte, start, end int64) {
+	winMap := d.state[stateID]
+	w := d.toWindow(wKey)
+	kMap := winMap[w]
+	if kMap == nil {
+		// Nothing is stored for this state and window, so there is nothing
+		// to clear. Assigning into the nil map below would panic.
+		return
+	}
+	data := kMap[string(uKey)]
+
+	lo, hi := findRange(data.Bag, start, end)
+	slog.Debug("State() OrderedList.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", slog.Int64("start", start), slog.Int64("end", end)), "lo", lo, "hi", hi, slog.Any("PreClearData", data.Bag))
+
+	cleared := slices.Delete(data.Bag, lo, hi)
+	// Store the remaining entries back under the key, rather than deleting the key.
+	// Delete makes it difficult to delete the persisted stage state for the key.
+	kMap[string(uKey)] = StateData{Bag: cleared}
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go
b/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go
new file mode 100644
index 00000000000..1d049710418
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go
@@ -0,0 +1,222 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package engine
+
+import (
+ "bytes"
+ "encoding/binary"
+ "math"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "google.golang.org/protobuf/encoding/protowire"
+)
+
+// TestCompareTimestampSuffixes validates the ordering of entries by their
+// leading varint timestamps.
+func TestCompareTimestampSuffixes(t *testing.T) {
+	t.Run("simple", func(t *testing.T) {
+		loI := int64(math.MinInt64)
+		hiI := int64(math.MaxInt64)
+
+		// Note: compareTimestampSuffixes decodes a leading varint, so these
+		// big-endian bytes are read as the varints 0 and 127 respectively.
+		// The "varint" case below covers the real encoding.
+		loB := binary.BigEndian.AppendUint64(nil, uint64(loI))
+		hiB := binary.BigEndian.AppendUint64(nil, uint64(hiI))
+
+		if compareTimestampSuffixes(loB, hiB) != (loI < hiI) {
+			t.Errorf("lo vs Hi%v < %v: bytes %v vs %v, %v %v", loI, hiI, loB, hiB, loI < hiI, compareTimestampSuffixes(loB, hiB))
+		}
+	})
+	t.Run("varint", func(t *testing.T) {
+		tests := []struct{ a, b int64 }{
+			{math.MinInt64, math.MaxInt64},
+			{math.MaxInt64, math.MinInt64},
+			{-1, 0},
+			{0, 1},
+			{11, 11},
+		}
+		for _, test := range tests {
+			// Negative timestamps are encoded as the varint of their two's complement bits.
+			aB := protowire.AppendVarint(nil, uint64(test.a))
+			bB := protowire.AppendVarint(nil, uint64(test.b))
+			if got, want := compareTimestampSuffixes(aB, bB), test.a < test.b; got != want {
+				t.Errorf("compareTimestampSuffixes(%v, %v) = %v, want %v", test.a, test.b, got, want)
+			}
+		}
+	})
+}
+
+// TestOrderedListState validates that appended entries are split, kept sorted
+// by timestamp, and retrievable and clearable by timestamp range, for several
+// value encodings.
+func TestOrderedListState(t *testing.T) {
+	time1 := protowire.AppendVarint(nil, 11)
+	time2 := protowire.AppendVarint(nil, 22)
+	time3 := protowire.AppendVarint(nil, 33)
+	time4 := protowire.AppendVarint(nil, 44)
+	time5 := protowire.AppendVarint(nil, 55)
+
+	wKey := []byte{} // global window.
+	uKey := []byte("\u0007userkey")
+	linkID := LinkID{
+		Transform: "dofn",
+		Local:     "localStateName",
+	}
+	// cc concatenates a timestamp prefix with the encoded value bytes.
+	cc := func(a []byte, b ...byte) []byte {
+		return bytes.Join([][]byte{a, b}, []byte{})
+	}
+	// lpLen returns the total length of a length prefixed value: the varint
+	// prefix, plus the number of bytes it declares.
+	lpLen := func(b []byte) int {
+		l, n := protowire.ConsumeVarint(b)
+		return int(l) + n
+	}
+
+	t.Run("bool", func(t *testing.T) {
+		d := TentativeData{
+			stateTypeLen: map[LinkID]func([]byte) int{
+				linkID: func(_ []byte) int {
+					return 1
+				},
+			},
+		}
+
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 1))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 1))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 1))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0))
+
+		got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+		want := [][]byte{
+			cc(time1, 1),
+			cc(time2, 0),
+			cc(time3, 1),
+			cc(time4, 0),
+			cc(time5, 1),
+		}
+		if d := cmp.Diff(want, got); d != "" {
+			t.Errorf("OrderedList booleans \n%v", d)
+		}
+
+		d.ClearOrderedListState(linkID, wKey, uKey, 12, 54)
+		got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+		want = [][]byte{
+			cc(time1, 1),
+			cc(time5, 1),
+		}
+		if d := cmp.Diff(want, got); d != "" {
+			t.Errorf("OrderedList booleans, after clear\n%v", d)
+		}
+	})
+	t.Run("float64", func(t *testing.T) {
+		d := TentativeData{
+			stateTypeLen: map[LinkID]func([]byte) int{
+				linkID: func(_ []byte) int {
+					return 8
+				},
+			},
+		}
+
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 0, 0, 0, 0, 0, 0, 0, 1))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 0, 0, 0, 0, 0, 0, 1, 0))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 0, 0, 0, 0, 0, 1, 0, 0))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0, 0, 0, 0, 1, 0, 0, 0))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0, 0, 0, 1, 0, 0, 0, 0))
+
+		got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+		want := [][]byte{
+			cc(time1, 0, 0, 0, 0, 0, 0, 1, 0),
+			cc(time2, 0, 0, 0, 0, 1, 0, 0, 0),
+			cc(time3, 0, 0, 0, 0, 0, 1, 0, 0),
+			cc(time4, 0, 0, 0, 1, 0, 0, 0, 0),
+			cc(time5, 0, 0, 0, 0, 0, 0, 0, 1),
+		}
+		if d := cmp.Diff(want, got); d != "" {
+			t.Errorf("OrderedList float64s \n%v", d)
+		}
+
+		// Each range covers exactly one timestamp, since the end is exclusive.
+		d.ClearOrderedListState(linkID, wKey, uKey, 11, 12)
+		d.ClearOrderedListState(linkID, wKey, uKey, 33, 34)
+		d.ClearOrderedListState(linkID, wKey, uKey, 55, 56)
+
+		got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+		want = [][]byte{
+			cc(time2, 0, 0, 0, 0, 1, 0, 0, 0),
+			cc(time4, 0, 0, 0, 1, 0, 0, 0, 0),
+		}
+		if d := cmp.Diff(want, got); d != "" {
+			t.Errorf("OrderedList float64s, after clear \n%v", d)
+		}
+	})
+
+	t.Run("varint", func(t *testing.T) {
+		d := TentativeData{
+			stateTypeLen: map[LinkID]func([]byte) int{
+				linkID: func(b []byte) int {
+					_, n := protowire.ConsumeVarint(b)
+					return int(n)
+				},
+			},
+		}
+
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, protowire.AppendVarint(nil, 56)...))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, protowire.AppendVarint(nil, 20067)...))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, protowire.AppendVarint(nil, 7777777)...))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, protowire.AppendVarint(nil, 424242)...))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, protowire.AppendVarint(nil, 0)...))
+
+		got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+		want := [][]byte{
+			cc(time1, protowire.AppendVarint(nil, 424242)...),
+			cc(time2, protowire.AppendVarint(nil, 56)...),
+			cc(time3, protowire.AppendVarint(nil, 7777777)...),
+			cc(time4, protowire.AppendVarint(nil, 20067)...),
+			cc(time5, protowire.AppendVarint(nil, 0)...),
+		}
+		if d := cmp.Diff(want, got); d != "" {
+			t.Errorf("OrderedList varints \n%v", d)
+		}
+	})
+	t.Run("lp", func(t *testing.T) {
+		d := TentativeData{
+			stateTypeLen: map[LinkID]func([]byte) int{
+				linkID: lpLen,
+			},
+		}
+
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, []byte("\u0003one")...))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, []byte("\u0003two")...))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, []byte("\u0005three")...))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, []byte("\u0004four")...))
+		d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...))
+
+		got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+		want := [][]byte{
+			cc(time1, []byte("\u0003one")...),
+			cc(time2, []byte("\u0003two")...),
+			cc(time3, []byte("\u0005three")...),
+			cc(time4, []byte("\u0004four")...),
+			cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...),
+		}
+		if d := cmp.Diff(want, got); d != "" {
+			t.Errorf("OrderedList length prefixed \n%v", d)
+		}
+	})
+	t.Run("lp_onecall", func(t *testing.T) {
+		d := TentativeData{
+			stateTypeLen: map[LinkID]func([]byte) int{
+				linkID: lpLen,
+			},
+		}
+		// All entries arrive concatenated in a single append, out of order.
+		d.AppendOrderedListState(linkID, wKey, uKey, bytes.Join([][]byte{
+			time5, []byte("\u0019FourHundredAndEleventyTwo"),
+			time3, []byte("\u0005three"),
+			time2, []byte("\u0003two"),
+			time1, []byte("\u0003one"),
+			time4, []byte("\u0004four"),
+		}, nil))
+
+		got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
+		want := [][]byte{
+			cc(time1, []byte("\u0003one")...),
+			cc(time2, []byte("\u0003two")...),
+			cc(time3, []byte("\u0005three")...),
+			cc(time4, []byte("\u0004four")...),
+			cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...),
+		}
+		if d := cmp.Diff(want, got); d != "" {
+			t.Errorf("OrderedList length prefixed, one call \n%v", d)
+		}
+	})
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
index 00e18c669af..7180bb456f1 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
@@ -269,8 +269,10 @@ func (em *ElementManager) StageAggregates(ID string) {
// StageStateful marks the given stage as stateful, which means elements are
// processed by key.
-func (em *ElementManager) StageStateful(ID string) {
- em.stages[ID].stateful = true
+func (em *ElementManager) StageStateful(ID string, stateTypeLen
map[LinkID]func([]byte) int) {
+ ss := em.stages[ID]
+ ss.stateful = true
+ ss.stateTypeLen = stateTypeLen
}
// StageOnWindowExpiration marks the given stage as stateful, which means
elements are
@@ -669,7 +671,9 @@ func (em *ElementManager) StateForBundle(rb RunBundle)
TentativeData {
ss := em.stages[rb.StageID]
ss.mu.Lock()
defer ss.mu.Unlock()
- var ret TentativeData
+ ret := TentativeData{
+ stateTypeLen: ss.stateTypeLen,
+ }
keys := ss.inprogressKeysByBundle[rb.BundleID]
// TODO(lostluck): Also track windows per bundle, to reduce copying.
if len(ss.state) > 0 {
@@ -1136,6 +1140,7 @@ type stageState struct {
inprogressKeys set[string]
// all keys that are assigned to bundles.
inprogressKeysByBundle map[string]set[string]
// bundle to key assignments.
state map[LinkID]map[typex.Window]map[string]StateData
// state data for this stage, from {tid, stateID} -> window -> userKey
+ stateTypeLen map[LinkID]func([]byte) int
// map from state to a function that will produce the total length of a single
value in bytes.
// Accounting for handling watermark holds for timers.
// We track the count of timers with the same hold, and clear it from
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go
b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index fde62f00c7c..8b56c30eb61 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -316,7 +316,7 @@ func executePipeline(ctx context.Context, wks
map[string]*worker.W, j *jobservic
sort.Strings(outputs)
em.AddStage(stage.ID, []string{stage.primaryInput},
outputs, stage.sideInputs)
if stage.stateful {
- em.StageStateful(stage.ID)
+ em.StageStateful(stage.ID, stage.stateTypeLen)
}
if stage.onWindowExpiration.TimerFamily != "" {
slog.Debug("OnWindowExpiration",
slog.String("stage", stage.ID), slog.Any("values", stage.onWindowExpiration))
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index 894a6e1427a..af559a92ab4 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -174,7 +174,8 @@ func (s *Server) Prepare(ctx context.Context, req
*jobpb.PrepareJobRequest) (_ *
// Validate all the state features
for _, spec := range pardo.GetStateSpecs() {
isStateful = true
- check("StateSpec.Protocol.Urn",
spec.GetProtocol().GetUrn(), urns.UserStateBag, urns.UserStateMultiMap)
+ check("StateSpec.Protocol.Urn",
spec.GetProtocol().GetUrn(),
+ urns.UserStateBag,
urns.UserStateMultiMap, urns.UserStateOrderedList)
}
// Validate all the timer features
for _, spec := range pardo.GetTimerFamilySpecs() {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go
b/sdks/go/pkg/beam/runners/prism/internal/stage.go
index 9dd6cbdafec..e1e942a06f0 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/stage.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go
@@ -35,6 +35,7 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/encoding/prototext"
+ "google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
)
@@ -73,6 +74,10 @@ type stage struct {
hasTimers []engine.StaticTimerID
processingTimeTimers map[string]bool
+ // stateTypeLen maps each state's LinkID to a function returning the encoded
+ // length in bytes of a single value of that state's element type.
+ // Only used for OrderedListState, which must manipulate individual state
+ // datavalues.
+ stateTypeLen map[engine.LinkID]func([]byte) int
+
exe transformExecuter
inputTransformID string
inputInfo engine.PColInfo
@@ -438,6 +443,38 @@ func buildDescriptor(stg *stage, comps *pipepb.Components,
wk *worker.W, em *eng
rewriteCoder(&s.SetSpec.ElementCoderId)
case *pipepb.StateSpec_OrderedListSpec:
rewriteCoder(&s.OrderedListSpec.ElementCoderId)
+ // Add the length determination helper for
OrderedList state values.
+ if stg.stateTypeLen == nil {
+ stg.stateTypeLen =
map[engine.LinkID]func([]byte) int{}
+ }
+ linkID := engine.LinkID{
+ Transform: tid,
+ Local: stateID,
+ }
+ var fn func([]byte) int
+ switch v :=
coders[s.OrderedListSpec.GetElementCoderId()]; v.GetSpec().GetUrn() {
+ case urns.CoderBool:
+ fn = func(_ []byte) int {
+ return 1
+ }
+ case urns.CoderDouble:
+ fn = func(_ []byte) int {
+ return 8
+ }
+ case urns.CoderVarInt:
+ fn = func(b []byte) int {
+ _, n :=
protowire.ConsumeVarint(b)
+ return int(n)
+ }
+ case urns.CoderLengthPrefix, urns.CoderBytes,
urns.CoderStringUTF8:
+ fn = func(b []byte) int {
+ l, n :=
protowire.ConsumeVarint(b)
+ return int(l) + n
+ }
+ default:
+ rewriteErr = fmt.Errorf("unknown coder
used for ordered list state after re-write id: %v coder: %v, for state %v for
transform %v in stage %v", s.OrderedListSpec.GetElementCoderId(), v, stateID,
tid, stg.ID)
+ }
+ stg.stateTypeLen[linkID] = fn
case *pipepb.StateSpec_CombiningSpec:
rewriteCoder(&s.CombiningSpec.AccumulatorCoderId)
case *pipepb.StateSpec_MapSpec:
diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
index 5312fd799c8..12e62ef84a8 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
@@ -95,8 +95,9 @@ var (
SideInputMultiMap = siUrn(pipepb.StandardSideInputTypes_MULTIMAP)
// UserState kinds
- UserStateBag = usUrn(pipepb.StandardUserStateTypes_BAG)
- UserStateMultiMap = usUrn(pipepb.StandardUserStateTypes_MULTIMAP)
+ UserStateBag = usUrn(pipepb.StandardUserStateTypes_BAG)
+ UserStateMultiMap = usUrn(pipepb.StandardUserStateTypes_MULTIMAP)
+ UserStateOrderedList = usUrn(pipepb.StandardUserStateTypes_ORDERED_LIST)
// WindowsFns
WindowFnGlobal = quickUrn(pipepb.GlobalWindowsPayload_PROPERTIES)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
index c2c988aa097..9d9058975b2 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -554,6 +554,11 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer)
error {
case *fnpb.StateKey_MultimapKeysUserState_:
mmkey := key.GetMultimapKeysUserState()
data =
b.OutputData.GetMultimapKeysState(engine.LinkID{Transform:
mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(),
mmkey.GetKey())
+ case *fnpb.StateKey_OrderedListUserState_:
+ olkey := key.GetOrderedListUserState()
+ data = b.OutputData.GetOrderedListState(
+ engine.LinkID{Transform:
olkey.GetTransformId(), Local: olkey.GetUserStateId()},
+ olkey.GetWindow(),
olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd())
default:
panic(fmt.Sprintf("unsupported StateKey
Get type: %T: %v", key.GetType(), prototext.Format(key)))
}
@@ -578,6 +583,11 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer)
error {
case *fnpb.StateKey_MultimapUserState_:
mmkey := key.GetMultimapUserState()
b.OutputData.AppendMultimapState(engine.LinkID{Transform:
mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(),
mmkey.GetKey(), mmkey.GetMapKey(), req.GetAppend().GetData())
+ case *fnpb.StateKey_OrderedListUserState_:
+ olkey := key.GetOrderedListUserState()
+ b.OutputData.AppendOrderedListState(
+ engine.LinkID{Transform:
olkey.GetTransformId(), Local: olkey.GetUserStateId()},
+ olkey.GetWindow(),
olkey.GetKey(), req.GetAppend().GetData())
default:
panic(fmt.Sprintf("unsupported StateKey
Append type: %T: %v", key.GetType(), prototext.Format(key)))
}
@@ -601,6 +611,10 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer)
error {
case *fnpb.StateKey_MultimapKeysUserState_:
mmkey := key.GetMultimapUserState()
b.OutputData.ClearMultimapKeysState(engine.LinkID{Transform:
mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(),
mmkey.GetKey())
+ case *fnpb.StateKey_OrderedListUserState_:
+ olkey := key.GetOrderedListUserState()
+
b.OutputData.ClearOrderedListState(engine.LinkID{Transform:
olkey.GetTransformId(), Local: olkey.GetUserStateId()},
+ olkey.GetWindow(),
olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd())
default:
panic(fmt.Sprintf("unsupported StateKey
Clear type: %T: %v", key.GetType(), prototext.Format(key)))
}