This is an automated email from the ASF dual-hosted git repository.
damondouglas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c6c3fd0490a Handle MultimapKeysSideInput in State GetRequests (#31632)
c6c3fd0490a is described below
commit c6c3fd0490ad73af1f4775d84de18c4ba8fb7af0
Author: Damon <[email protected]>
AuthorDate: Thu Jun 20 20:21:02 2024 -0700
Handle MultimapKeysSideInput in State GetRequests (#31632)
* Handle MultimapKeysSideInput in State GetRequests
* Assign data to keys
* Fix test name
* Fix import sort
---
.../beam/runners/prism/internal/worker/worker.go | 15 ++++
.../runners/prism/internal/worker/worker_test.go | 95 ++++++++++++++++++++++
2 files changed, 110 insertions(+)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
index 47fc2cccfc5..d8eb4c96149 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -468,6 +468,21 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error {
data = winMap[w]
+ case *fnpb.StateKey_MultimapKeysSideInput_:
+ mmkey := key.GetMultimapKeysSideInput()
+ wKey := mmkey.GetWindow()
+ var w typex.Window =
window.GlobalWindow{}
+ if len(wKey) > 0 {
+ w, err =
exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
+ if err != nil {
+
panic(fmt.Sprintf("error decoding multimap side input window key %v: %v", wKey,
err))
+ }
+ }
+ winMap :=
b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(),
Local: mmkey.GetSideInputId()}]
+ for k := range winMap[w] {
+ data = append(data, []byte(k))
+ }
+
case *fnpb.StateKey_MultimapSideInput_:
mmkey := key.GetMultimapSideInput()
wKey := mmkey.GetWindow()
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
index b87667eef38..e5b03214ae0 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
@@ -18,12 +18,16 @@ package worker
import (
"bytes"
"context"
+ "github.com/google/go-cmp/cmp"
"net"
+ "sort"
"sync"
"testing"
"time"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
@@ -97,6 +101,23 @@ func serveTestWorker(t *testing.T) (context.Context, *W, *grpc.ClientConn) {
return ctx, w, clientConn
}
+// closeSend closes the client side of a state stream returned by
+// serveTestWorkerStateStream. Callers typically defer it.
+type closeSend func()
+
+// serveTestWorkerStateStream starts a test worker via serveTestWorker and opens
+// a BeamFnState stream against it. It returns the worker, the open client
+// stream, and a closeSend func that calls CloseSend on the stream, reporting
+// (without aborting the test) any error it returns. The test is aborted if the
+// stream cannot be opened.
+func serveTestWorkerStateStream(t *testing.T) (*W, fnpb.BeamFnState_StateClient, closeSend) {
+	// Mark this as a helper so failures point at the calling test, not here.
+	t.Helper()
+	ctx, wk, clientConn := serveTestWorker(t)
+
+	stateCli := fnpb.NewBeamFnStateClient(clientConn)
+	stateStream, err := stateCli.State(ctx)
+	if err != nil {
+		t.Fatal("couldn't create state client:", err)
+	}
+	return wk, stateStream, func() {
+		t.Helper()
+		if err := stateStream.CloseSend(); err != nil {
+			t.Errorf("stateStream.CloseSend() = %v", err)
+		}
+	}
+}
+
func TestWorker_Logging(t *testing.T) {
ctx, _, clientConn := serveTestWorker(t)
@@ -291,3 +312,77 @@ func TestWorker_State_Iterable(t *testing.T) {
t.Errorf("stateStream.CloseSend() = %v", err)
}
}
+
+// TestWorker_State_MultimapKeysSideInput checks that a State Get request with a
+// MultimapKeysSideInput state key returns the keys stored in the active
+// bundle's MultiMapSideInputData for the requested window. It covers the
+// global window, sent as an empty window key, and an interval window, sent
+// encoded with the interval window coder.
+func TestWorker_State_MultimapKeysSideInput(t *testing.T) {
+	for _, tt := range []struct {
+		name string
+		w    typex.Window
+	}{
+		{
+			name: "global window",
+			w:    window.GlobalWindow{},
+		},
+		{
+			name: "interval window",
+			w: window.IntervalWindow{
+				Start: 1000,
+				End:   2000,
+			},
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			// The global window is requested with an empty window key. Any other
+			// window is encoded with the interval window coder, matching the
+			// decoder the worker applies to non-empty window keys.
+			var encW []byte
+			if !tt.w.Equals(window.GlobalWindow{}) {
+				var buf bytes.Buffer
+				if err := exec.MakeWindowEncoder(coder.NewIntervalWindow()).EncodeSingle(tt.w, &buf); err != nil {
+					t.Fatalf("error encoding window: %v, err: %v", tt.w, err)
+				}
+				encW = buf.Bytes()
+			}
+			wk, stateStream, done := serveTestWorkerStateStream(t)
+			defer done()
+			instID := wk.NextInst()
+			wk.activeInstructions[instID] = &B{
+				MultiMapSideInputData: map[SideInputKey]map[typex.Window]map[string][][]byte{
+					{
+						TransformID: "transformID",
+						Local:       "i1",
+					}: {
+						tt.w: map[string][][]byte{"a": {{1}}, "b": {{2}}},
+					},
+				},
+			}
+
+			req := &fnpb.StateRequest{
+				Id:            "first",
+				InstructionId: instID,
+				Request: &fnpb.StateRequest_Get{
+					Get: &fnpb.StateGetRequest{},
+				},
+				StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_MultimapKeysSideInput_{
+					MultimapKeysSideInput: &fnpb.StateKey_MultimapKeysSideInput{
+						TransformId: "transformID",
+						SideInputId: "i1",
+						Window:      encW,
+					},
+				}},
+			}
+			// A failed Send would otherwise surface later as a confusing Recv
+			// failure, so fail fast here.
+			if err := stateStream.Send(req); err != nil {
+				t.Fatal("couldn't send state request:", err)
+			}
+
+			resp, err := stateStream.Recv()
+			if err != nil {
+				t.Fatal("couldn't receive state response:", err)
+			}
+			if e := resp.GetError(); e != "" {
+				t.Fatalf("state response reported an error: %v", e)
+			}
+
+			// The worker collects keys by ranging over a map, so their order is
+			// unspecified. Sort the returned bytes before comparing. Each key
+			// here is a single byte, "a" (97) and "b" (98).
+			want := []int{'a', 'b'}
+			var got []int
+			for _, b := range resp.GetGet().GetData() {
+				got = append(got, int(b))
+			}
+			sort.Ints(got)
+
+			if !cmp.Equal(got, want) {
+				t.Errorf("didn't receive expected state response data: got %v, want %v", got, want)
+			}
+		})
+	}
+}