This is an automated email from the ASF dual-hosted git repository.

damondouglas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c6c3fd0490a Handle MultimapKeysSideInput in State GetRequests (#31632)
c6c3fd0490a is described below

commit c6c3fd0490ad73af1f4775d84de18c4ba8fb7af0
Author: Damon <[email protected]>
AuthorDate: Thu Jun 20 20:21:02 2024 -0700

    Handle MultimapKeysSideInput in State GetRequests (#31632)
    
    * Handle MultimapKeysSideInput in State GetRequests
    
    * Assign data to keys
    
    * Fix test name
    
    * Fix import sort
---
 .../beam/runners/prism/internal/worker/worker.go   | 15 ++++
 .../runners/prism/internal/worker/worker_test.go   | 95 ++++++++++++++++++++++
 2 files changed, 110 insertions(+)

diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
index 47fc2cccfc5..d8eb4c96149 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -468,6 +468,21 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error {
 
                                        data = winMap[w]
 
+                               case *fnpb.StateKey_MultimapKeysSideInput_:
+                                       mmkey := key.GetMultimapKeysSideInput()
+                                       wKey := mmkey.GetWindow()
+                                       var w typex.Window = 
window.GlobalWindow{}
+                                       if len(wKey) > 0 {
+                                               w, err = 
exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
+                                               if err != nil {
+                                                       
panic(fmt.Sprintf("error decoding multimap side input window key %v: %v", wKey, 
err))
+                                               }
+                                       }
+                                       winMap := 
b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(), 
Local: mmkey.GetSideInputId()}]
+                                       for k := range winMap[w] {
+                                               data = append(data, []byte(k))
+                                       }
+
                                case *fnpb.StateKey_MultimapSideInput_:
                                        mmkey := key.GetMultimapSideInput()
                                        wKey := mmkey.GetWindow()
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
index b87667eef38..e5b03214ae0 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
@@ -18,12 +18,16 @@ package worker
 import (
        "bytes"
        "context"
+       "github.com/google/go-cmp/cmp"
        "net"
+       "sort"
        "sync"
        "testing"
        "time"
 
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
        fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
        
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
@@ -97,6 +101,23 @@ func serveTestWorker(t *testing.T) (context.Context, *W, *grpc.ClientConn) {
        return ctx, w, clientConn
 }
 
+type closeSend func()
+
+// serveTestWorkerStateStream opens a BeamFnState stream to a worker from serveTestWorker; call the returned closeSend when done.
+func serveTestWorkerStateStream(t *testing.T) (*W, fnpb.BeamFnState_StateClient, closeSend) {
+	t.Helper() // attribute t.Fatal failures below to the calling test
+	ctx, wk, clientConn := serveTestWorker(t)
+	stateStream, err := fnpb.NewBeamFnStateClient(clientConn).State(ctx)
+	if err != nil {
+		t.Fatal("couldn't create state client:", err)
+	}
+	return wk, stateStream, func() {
+		if err := stateStream.CloseSend(); err != nil {
+			t.Errorf("stateStream.CloseSend() = %v", err)
+		}
+	}
+}
+
 func TestWorker_Logging(t *testing.T) {
        ctx, _, clientConn := serveTestWorker(t)
 
@@ -291,3 +312,77 @@ func TestWorker_State_Iterable(t *testing.T) {
                t.Errorf("stateStream.CloseSend() = %v", err)
        }
 }
+
+// TestWorker_State_MultimapKeysSideInput checks that a MultimapKeysSideInput Get request returns the side input's keys for the requested window.
+func TestWorker_State_MultimapKeysSideInput(t *testing.T) {
+	for _, tt := range []struct {
+		name string
+		w    typex.Window
+	}{
+		{
+			name: "global window",
+			w:    window.GlobalWindow{},
+		},
+		{
+			name: "interval window",
+			w: window.IntervalWindow{
+				Start: 1000,
+				End:   2000,
+			},
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			var encW []byte // stays empty for the global window; the worker treats an empty window key as GlobalWindow
+			if !tt.w.Equals(window.GlobalWindow{}) {
+				var buf bytes.Buffer
+				if err := exec.MakeWindowEncoder(coder.NewIntervalWindow()).EncodeSingle(tt.w, &buf); err != nil {
+					t.Fatalf("error encoding window: %v, err: %v", tt.w, err)
+				}
+				encW = buf.Bytes()
+			}
+			wk, stateStream, done := serveTestWorkerStateStream(t)
+			defer done()
+			instID := wk.NextInst()
+			wk.activeInstructions[instID] = &B{
+				MultiMapSideInputData: map[SideInputKey]map[typex.Window]map[string][][]byte{
+					{TransformID: "transformID", Local: "i1"}: {
+						tt.w: map[string][][]byte{"a": {{1}}, "b": {{2}}},
+					},
+				},
+			}
+
+			if err := stateStream.Send(&fnpb.StateRequest{
+				Id:            "first",
+				InstructionId: instID,
+				Request: &fnpb.StateRequest_Get{
+					Get: &fnpb.StateGetRequest{},
+				},
+				StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_MultimapKeysSideInput_{
+					MultimapKeysSideInput: &fnpb.StateKey_MultimapKeysSideInput{
+						TransformId: "transformID",
+						SideInputId: "i1",
+						Window:      encW,
+					},
+				}},
+			}); err != nil {
+				t.Fatal("couldn't send state request:", err)
+			}
+
+			resp, err := stateStream.Recv()
+			if err != nil {
+				t.Fatal("couldn't receive state response:", err)
+			}
+
+			// Keys "a" and "b" come back in Go map iteration order (unspecified), so sort before comparing.
+			want := []int{'a', 'b'}
+			var got []int
+			for _, b := range resp.GetGet().GetData() {
+				got = append(got, int(b))
+			}
+			sort.Ints(got)
+			if !cmp.Equal(got, want) {
+				t.Errorf("didn't receive expected state response data: got %v, want %v", got, want)
+			}
+		})
+	}
+}

Reply via email to