This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 991b4b517f7 [prism] Fusion base, reshuffle, cogbk. (#27737)
991b4b517f7 is described below

commit 991b4b517f7f0e182caa59742405cd6c0fdfb35f
Author: Robert Burke <[email protected]>
AuthorDate: Mon Jul 31 08:39:43 2023 -0700

    [prism] Fusion base, reshuffle, cogbk. (#27737)
    
    * [prism] Fusion base, reshuffle, cogbk.
    
    * silence logging, better message
    
    * precise reshuffle strategy filtering
    
    * Remove leftover comments
    
    * remove decommissioned function
    
    * fix typos.
    
    ---------
    
    Co-authored-by: lostluck <[email protected]>
---
 sdks/go/pkg/beam/core/runtime/exec/translate.go    |  11 +-
 .../prism/internal/engine/elementmanager.go        |   2 +-
 sdks/go/pkg/beam/runners/prism/internal/execute.go |  27 +-
 .../beam/runners/prism/internal/execute_test.go    | 151 ++++++++-
 .../beam/runners/prism/internal/handlerunner.go    |  70 +++-
 .../prism/internal/jobservices/management.go       |  38 ++-
 .../pkg/beam/runners/prism/internal/preprocess.go  | 249 +++++++++++++-
 .../beam/runners/prism/internal/preprocess_test.go |  10 +-
 sdks/go/pkg/beam/runners/prism/internal/stage.go   | 373 +++++++++++----------
 .../beam/runners/prism/internal/testdofns_test.go  |  14 +
 .../runners/prism/internal/unimplemented_test.go   |  37 +-
 .../pkg/beam/runners/prism/internal/urns/urns.go   |   1 +
 .../beam/runners/prism/internal/worker/worker.go   |   2 +-
 .../beam/runners/universal/extworker/extworker.go  |   4 +-
 .../go/pkg/beam/runners/universal/runnerlib/job.go |   2 +-
 15 files changed, 760 insertions(+), 231 deletions(-)

diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go 
b/sdks/go/pkg/beam/core/runtime/exec/translate.go
index 65827d05838..02a1418880e 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go
@@ -193,7 +193,11 @@ func newBuilder(desc *fnpb.ProcessBundleDescriptor) 
(*builder, error) {
 
                input := unmarshalKeyedValues(transform.GetInputs())
                for i, from := range input {
-                       succ[from] = append(succ[from], linkID{id, i})
+                       // We don't need to multiplex successors for pardo side 
inputs.
+                       // so we only do so for SDK side Flattens.
+                       if i == 0 || transform.GetSpec().GetUrn() == 
graphx.URNFlatten {
+                               succ[from] = append(succ[from], linkID{id, i})
+                       }
                }
                output := unmarshalKeyedValues(transform.GetOutputs())
                for _, to := range output {
@@ -731,7 +735,10 @@ func (b *builder) makeLink(from string, id linkID) (Node, 
error) {
                        }
                        // Strip PCollections from Expand nodes, as CoGBK 
metrics are handled by
                        // the DataSource that preceeds them.
-                       trueOut := out[0].(*PCollection).Out
+                       trueOut := out[0]
+                       if pcol, ok := trueOut.(*PCollection); ok {
+                               trueOut = pcol.Out
+                       }
                        b.units = b.units[:len(b.units)-1]
                        u = &Expand{UID: b.idgen.New(), ValueDecoders: 
decoders, Out: trueOut}
 
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
index 95ad2e562d4..c8721e1a207 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
@@ -570,7 +570,7 @@ func (ss *stageState) startBundle(watermark mtime.Time, 
genBundID func() string)
 
        var toProcess, notYet []element
        for _, e := range ss.pending {
-               if !ss.aggregate || ss.aggregate && 
ss.strat.EarliestCompletion(e.window) <= watermark {
+               if !ss.aggregate || ss.aggregate && 
ss.strat.EarliestCompletion(e.window) < watermark {
                        toProcess = append(toProcess, e)
                } else {
                        notYet = append(notYet, e)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go 
b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index 13c8b2b127c..ecff740ed86 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -60,7 +60,11 @@ func RunPipeline(j *jobservices.Job) {
        j.SendMsg("running " + j.String())
        j.Running()
 
-       executePipeline(j.RootCtx, wk, j)
+       err := executePipeline(j.RootCtx, wk, j)
+       if err != nil {
+               j.Failed(err)
+               return
+       }
        j.SendMsg("pipeline completed " + j.String())
 
        // Stop the worker.
@@ -126,14 +130,14 @@ func externalEnvironment(ctx context.Context, ep 
*pipepb.ExternalPayload, wk *wo
 type transformExecuter interface {
        ExecuteUrns() []string
        ExecuteWith(t *pipepb.PTransform) string
-       ExecuteTransform(tid string, t *pipepb.PTransform, comps 
*pipepb.Components, watermark mtime.Time, data [][]byte) *worker.B
+       ExecuteTransform(stageID, tid string, t *pipepb.PTransform, comps 
*pipepb.Components, watermark mtime.Time, data [][]byte) *worker.B
 }
 
 type processor struct {
        transformExecuters map[string]transformExecuter
 }
 
-func executePipeline(ctx context.Context, wk *worker.W, j *jobservices.Job) {
+func executePipeline(ctx context.Context, wk *worker.W, j *jobservices.Job) 
error {
        pipeline := j.Pipeline
        comps := proto.Clone(pipeline.GetComponents()).(*pipepb.Components)
 
@@ -145,7 +149,8 @@ func executePipeline(ctx context.Context, wk *worker.W, j 
*jobservices.Job) {
                Combine(CombineCharacteristic{EnableLifting: true}),
                ParDo(ParDoCharacteristic{DisableSDF: true}),
                Runner(RunnerCharacteristic{
-                       SDKFlatten: false,
+                       SDKFlatten:   false,
+                       SDKReshuffle: false,
                }),
        }
 
@@ -175,10 +180,7 @@ func executePipeline(ctx context.Context, wk *worker.W, j 
*jobservices.Job) {
        // TODO move this loop and code into the preprocessor instead.
        stages := map[string]*stage{}
        var impulses []string
-       for i, stage := range topo {
-               if len(stage.transforms) != 1 {
-                       panic(fmt.Sprintf("unsupported stage[%d]: contains 
multiple transforms: %v; TODO: implement fusion", i, stage.transforms))
-               }
+       for _, stage := range topo {
                tid := stage.transforms[0]
                t := ts[tid]
                urn := t.GetSpec().GetUrn()
@@ -255,16 +257,16 @@ func executePipeline(ctx context.Context, wk *worker.W, j 
*jobservices.Job) {
                        wk.Descriptors[stage.ID] = stage.desc
                case wk.ID:
                        // Great! this is for this environment. // Broken 
abstraction.
-                       buildStage(stage, tid, t, comps, wk)
+                       buildDescriptor(stage, comps, wk)
                        stages[stage.ID] = stage
                        slog.Debug("pipelineBuild", slog.Group("stage", 
slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName())))
                        outputs := maps.Keys(stage.OutputsToCoders)
                        sort.Strings(outputs)
-                       em.AddStage(stage.ID, []string{stage.mainInputPCol}, 
stage.sides, outputs)
+                       em.AddStage(stage.ID, []string{stage.primaryInput}, 
stage.sides, outputs)
                default:
                        err := fmt.Errorf("unknown environment[%v]", 
t.GetEnvironmentId())
                        slog.Error("Execute", err)
-                       panic(err)
+                       return err
                }
        }
 
@@ -285,6 +287,7 @@ func executePipeline(ctx context.Context, wk *worker.W, j 
*jobservices.Job) {
                }(rb)
        }
        slog.Info("pipeline done!", slog.String("job", j.String()))
+       return nil
 }
 
 func collectionPullDecoder(coldCId string, coders map[string]*pipepb.Coder, 
comps *pipepb.Components) func(io.Reader) []byte {
@@ -300,7 +303,7 @@ func getWindowValueCoders(comps *pipepb.Components, col 
*pipepb.PCollection, cod
 
 func getOnlyValue[K comparable, V any](in map[K]V) V {
        if len(in) != 1 {
-               panic(fmt.Sprintf("expected single value map, had %v", len(in)))
+               panic(fmt.Sprintf("expected single value map, had %v - %v", 
len(in), in))
        }
        for _, v := range in {
                return v
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go
index 96639a33015..1a5ae7989a0 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go
@@ -27,6 +27,7 @@ import (
        "github.com/apache/beam/sdks/v2/go/pkg/beam"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/options/jobopts"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/register"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal"
        
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal"
@@ -319,6 +320,61 @@ func TestRunner_Pipelines(t *testing.T) {
                                        Want: []int{16, 17, 18},
                                }, sum)
                        },
+               }, {
+                       name: "sideinput_sameAsMainInput",
+                       pipeline: func(s beam.Scope) {
+                               imp := beam.Impulse(s)
+                               col0 := beam.ParDo(s, dofn1, imp)
+                               sum := beam.ParDo(s, dofn3x1, col0, 
beam.SideInput{Input: col0}, beam.SideInput{Input: col0})
+                               beam.ParDo(s, &int64Check{
+                                       Name: "sum sideinput check",
+                                       Want: []int{13, 14, 15},
+                               }, sum)
+                       },
+               }, {
+                       name: "sideinput_sameAsMainInput+Derived",
+                       pipeline: func(s beam.Scope) {
+                               imp := beam.Impulse(s)
+                               col0 := beam.ParDo(s, dofn1, imp)
+                               col1 := beam.ParDo(s, dofn2, col0)
+                               // Doesn't matter which of col0 or col1 is used.
+                               sum := beam.ParDo(s, dofn3x1, col0, 
beam.SideInput{Input: col0}, beam.SideInput{Input: col1})
+                               beam.ParDo(s, &int64Check{
+                                       Name: "sum sideinput check",
+                                       Want: []int{16, 17, 18},
+                               }, sum)
+                       },
+               }, {
+                       // Main input is getting duplicated data, since it's 
being executed twice...
+                       // But that doesn't make any sense
+                       name: "sideinput_2iterable1Data2",
+                       pipeline: func(s beam.Scope) {
+                               imp := beam.Impulse(s)
+                               col0 := beam.ParDo(s, dofn1, imp)
+                               col1 := beam.ParDo(s, dofn2, col0)
+                               col2 := beam.ParDo(s, dofn2, col0)
+                               // Doesn't matter which of col1 or col2 is used.
+                               sum := beam.ParDo(s, dofn3x1, col0, 
beam.SideInput{Input: col2}, beam.SideInput{Input: col1})
+                               beam.ParDo(s, &int64Check{
+                                       Name: "iter sideinput check",
+                                       Want: []int{19, 20, 21},
+                               }, sum)
+                       },
+               }, {
+                       // Re-use the same side inputs sequentially (the two 
consumers should be in the same stage.)
+                       name: "sideinput_two_2iterable1Data",
+                       pipeline: func(s beam.Scope) {
+                               imp := beam.Impulse(s)
+                               col0 := beam.ParDo(s, dofn1, imp)
+                               sideIn1 := beam.ParDo(s, dofn1, imp)
+                               sideIn2 := beam.ParDo(s, dofn1, imp)
+                               col1 := beam.ParDo(s, dofn3x1, col0, 
beam.SideInput{Input: sideIn1}, beam.SideInput{Input: sideIn2})
+                               sum := beam.ParDo(s, dofn3x1, col1, 
beam.SideInput{Input: sideIn1}, beam.SideInput{Input: sideIn2})
+                               beam.ParDo(s, &int64Check{
+                                       Name: "check_sideinput_re-use",
+                                       Want: []int{25, 26, 27},
+                               }, sum)
+                       },
                }, {
                        name: "combine_perkey",
                        pipeline: func(s beam.Scope) {
@@ -380,6 +436,30 @@ func TestRunner_Pipelines(t *testing.T) {
                                }, flat)
                                passert.NonEmpty(s, flat)
                        },
+               }, {
+                       name: "gbk_into_gbk",
+                       pipeline: func(s beam.Scope) {
+                               imp := beam.Impulse(s)
+                               col1 := beam.ParDo(s, dofnKV, imp)
+                               gbk1 := beam.GroupByKey(s, col1)
+                               col2 := beam.ParDo(s, dofnGBKKV, gbk1)
+                               gbk2 := beam.GroupByKey(s, col2)
+                               out := beam.ParDo(s, dofnGBK, gbk2)
+                               passert.Equals(s, out, int64(9), int64(12))
+                       },
+               }, {
+                       name: "lperror_gbk_into_cogbk_shared_input",
+                       pipeline: func(s beam.Scope) {
+                               want := beam.CreateList(s, []int{0})
+                               fruits := beam.CreateList(s, []int64{42, 42, 
42})
+                               fruitsKV := beam.AddFixedKey(s, fruits)
+
+                               fruitsGBK := beam.GroupByKey(s, fruitsKV)
+                               fooKV := beam.ParDo(s, toFoo, fruitsGBK)
+                               fruitsFooCoGBK := beam.CoGroupByKey(s, 
fruitsKV, fooKV)
+                               got := beam.ParDo(s, toID, fruitsFooCoGBK)
+                               passert.Equals(s, got, want)
+                       },
                },
        }
        // TODO: Explicit DoFn Failure case.
@@ -429,8 +509,75 @@ func TestFailure(t *testing.T) {
        if err == nil {
                t.Fatalf("expected pipeline failure, but got a success")
        }
-       // Job failure state reason isn't communicated with the state change 
over the API
-       // so we can't check for a reason here.
+       if want := "doFnFail: failing as intended"; 
!strings.Contains(err.Error(), want) {
+               t.Fatalf("expected pipeline failure with %q, but was %v", want, 
err)
+       }
+}
+
+func TestRunner_Passert(t *testing.T) {
+       initRunner(t)
+       tests := []struct {
+               name     string
+               pipeline func(s beam.Scope)
+               metrics  func(t *testing.T, pr beam.PipelineResult)
+       }{
+               {
+                       name: "Empty",
+                       pipeline: func(s beam.Scope) {
+                               imp := beam.Impulse(s)
+                               col1 := beam.ParDo(s, dofnEmpty, imp)
+                               passert.Empty(s, col1)
+                       },
+               }, {
+                       name: "Equals-TwoEmpty",
+                       pipeline: func(s beam.Scope) {
+                               imp := beam.Impulse(s)
+                               col1 := beam.ParDo(s, dofnEmpty, imp)
+                               col2 := beam.ParDo(s, dofnEmpty, imp)
+                               passert.Equals(s, col1, col2)
+                       },
+               }, {
+                       name: "Equals",
+                       pipeline: func(s beam.Scope) {
+                               imp := beam.Impulse(s)
+                               col1 := beam.ParDo(s, dofn1, imp)
+                               col2 := beam.ParDo(s, dofn1, imp)
+                               passert.Equals(s, col1, col2)
+                       },
+               },
+       }
+       for _, test := range tests {
+               t.Run(test.name, func(t *testing.T) {
+                       p, s := beam.NewPipelineWithRoot()
+                       test.pipeline(s)
+                       pr, err := executeWithT(context.Background(), t, p)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       if test.metrics != nil {
+                               test.metrics(t, pr)
+                       }
+               })
+       }
+}
+
+func toFoo(et beam.EventTime, id int, _ func(*int64) bool) (int, string) {
+       return id, "ooo"
+}
+
+func toID(et beam.EventTime, id int, fruitIter func(*int64) bool, fooIter 
func(*string) bool) int {
+       var fruit int64
+       for fruitIter(&fruit) {
+       }
+       var foo string
+       for fooIter(&foo) {
+       }
+       return id
+}
+
+func init() {
+       register.Function3x2(toFoo)
+       register.Function4x1(toID)
 }
 
 // TODO: PCollection metrics tests, in particular for element counts, in multi 
transform pipelines
diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go 
b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go
index e841620625e..27303f03b70 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go
@@ -41,8 +41,9 @@ import (
 // RunnerCharacteristic holds the configuration for Runner based transforms,
 // such as GBKs, Flattens.
 type RunnerCharacteristic struct {
-       SDKFlatten bool // Sets whether we should force an SDK side flatten.
-       SDKGBK     bool // Sets whether the GBK should be handled by the SDK, 
if possible by the SDK.
+       SDKFlatten   bool // Sets whether we should force an SDK side flatten.
+       SDKGBK       bool // Sets whether the GBK should be handled by the SDK, 
if possible by the SDK.
+       SDKReshuffle bool
 }
 
 func Runner(config any) *runner {
@@ -63,13 +64,72 @@ func (*runner) ConfigCharacteristic() reflect.Type {
        return reflect.TypeOf((*RunnerCharacteristic)(nil)).Elem()
 }
 
+var _ transformPreparer = (*runner)(nil)
+
+func (*runner) PrepareUrns() []string {
+       return []string{urns.TransformReshuffle}
+}
+
+// PrepareTransform handles special processing with respect to runner transforms, 
like reshuffle.
+func (h *runner) PrepareTransform(tid string, t *pipepb.PTransform, comps 
*pipepb.Components) (*pipepb.Components, []string) {
+       // TODO: Implement the windowing strategy the "backup" transforms used 
for Reshuffle.
+       // TODO: Implement a fusion break for reshuffles.
+
+       if h.config.SDKReshuffle {
+               panic("SDK side reshuffle not yet supported")
+       }
+
+       // A Reshuffle, in principle, is a no-op on the pipeline structure, WRT 
correctness.
+       // It could however affect performance, so it exists to tell the runner 
that this
+       // point in the pipeline needs a fusion break, to enable the pipeline 
to change its
+       // degree of parallelism.
+       //
+       // The change of parallelism goes both ways. It could allow for larger 
batch sizes, or
+       // enable smaller batch sizes downstream if it is in fact parallelizable.
+       //
+       // But for a single transform node per stage runner, we can elide it 
entirely,
+       // since the input collection and output collection types match.
+
+       // Get the input and output PCollections, there should only be 1 each.
+       if len(t.GetInputs()) != 1 {
+               panic("Expected single input PCollection in reshuffle: " + 
prototext.Format(t))
+       }
+       if len(t.GetOutputs()) != 1 {
+               panic("Expected single output PCollection in reshuffle: " + 
prototext.Format(t))
+       }
+
+       inColID := getOnlyValue(t.GetInputs())
+       outColID := getOnlyValue(t.GetOutputs())
+
+       // We need to find all Transforms that consume the output collection and
+       // replace them so they consume the input PCollection directly.
+
+       // We need to remove the consumers of the output PCollection.
+       toRemove := []string{}
+
+       for _, t := range comps.GetTransforms() {
+               for li, gi := range t.GetInputs() {
+                       if gi == outColID {
+                               // The whole s
+                               t.GetInputs()[li] = inColID
+                       }
+               }
+       }
+
+       // And all the sub transforms.
+       toRemove = append(toRemove, t.GetSubtransforms()...)
+
+       // Return the new components which is the transforms consumer
+       return nil, toRemove
+}
+
 var _ transformExecuter = (*runner)(nil)
 
 func (*runner) ExecuteUrns() []string {
-       return []string{urns.TransformFlatten, urns.TransformGBK}
+       return []string{urns.TransformFlatten, urns.TransformGBK, 
urns.TransformReshuffle}
 }
 
-// ExecuteWith returns what environment the
+// ExecuteWith returns what environment the transform should execute in.
 func (h *runner) ExecuteWith(t *pipepb.PTransform) string {
        urn := t.GetSpec().GetUrn()
        if urn == urns.TransformFlatten && !h.config.SDKFlatten {
@@ -82,7 +142,7 @@ func (h *runner) ExecuteWith(t *pipepb.PTransform) string {
 }
 
 // ExecuteTransform handles special processing with respect to runner specific 
transforms
-func (h *runner) ExecuteTransform(tid string, t *pipepb.PTransform, comps 
*pipepb.Components, watermark mtime.Time, inputData [][]byte) *worker.B {
+func (h *runner) ExecuteTransform(stageID, tid string, t *pipepb.PTransform, 
comps *pipepb.Components, watermark mtime.Time, inputData [][]byte) *worker.B {
        urn := t.GetSpec().GetUrn()
        var data [][]byte
        var onlyOut string
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go 
b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index 0c16b5eb34f..953ee50c559 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -24,6 +24,7 @@ import (
        jobpb 
"github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1"
        pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
+       "golang.org/x/exp/maps"
        "golang.org/x/exp/slog"
        "google.golang.org/protobuf/types/known/timestamppb"
 )
@@ -101,7 +102,9 @@ func (s *Server) Prepare(ctx context.Context, req 
*jobpb.PrepareJobRequest) (*jo
        }
 
        // Inspect Transforms for unsupported features.
-       for _, t := range job.Pipeline.GetComponents().GetTransforms() {
+       bypassedWindowingStrategies := map[string]bool{}
+       ts := job.Pipeline.GetComponents().GetTransforms()
+       for _, t := range ts {
                urn := t.GetSpec().GetUrn()
                switch urn {
                case urns.TransformImpulse,
@@ -112,6 +115,23 @@ func (s *Server) Prepare(ctx context.Context, req 
*jobpb.PrepareJobRequest) (*jo
                        urns.TransformAssignWindows:
                // Very few expected transforms types for submitted pipelines.
                // Most URNs are for the runner to communicate back to the SDK 
for execution.
+               case urns.TransformReshuffle:
+                       // Reshuffles use features we don't yet support, but we 
would like to
+                       // support them by making them the no-op they are, and 
be precise about
+                       // what we're ignoring.
+                       var cols []string
+                       for _, stID := range t.GetSubtransforms() {
+                               st := ts[stID]
+                               // Only check the outputs, since reshuffle 
re-instates any previous WindowingStrategy
+                               // so we still validate the strategy used by 
the input, avoiding skips.
+                               cols = append(cols, 
maps.Values(st.GetOutputs())...)
+                       }
+
+                       pcs := job.Pipeline.GetComponents().GetPcollections()
+                       for _, col := range cols {
+                               wsID := pcs[col].GetWindowingStrategyId()
+                               bypassedWindowingStrategies[wsID] = true
+                       }
                case "":
                        // Composites can often have no spec
                        if len(t.GetSubtransforms()) > 0 {
@@ -124,24 +144,26 @@ func (s *Server) Prepare(ctx context.Context, req 
*jobpb.PrepareJobRequest) (*jo
        }
 
        // Inspect Windowing strategies for unsupported features.
-       for _, ws := range 
job.Pipeline.GetComponents().GetWindowingStrategies() {
+       for wsID, ws := range 
job.Pipeline.GetComponents().GetWindowingStrategies() {
                check("WindowingStrategy.AllowedLateness", 
ws.GetAllowedLateness(), int64(0))
                check("WindowingStrategy.ClosingBehaviour", 
ws.GetClosingBehavior(), pipepb.ClosingBehavior_EMIT_IF_NONEMPTY)
                check("WindowingStrategy.AccumulationMode", 
ws.GetAccumulationMode(), pipepb.AccumulationMode_DISCARDING)
                if ws.GetWindowFn().GetUrn() != urns.WindowFnSession {
                        check("WindowingStrategy.MergeStatus", 
ws.GetMergeStatus(), pipepb.MergeStatus_NON_MERGING)
                }
-               check("WindowingStrategy.OnTimerBehavior", 
ws.GetOnTimeBehavior(), pipepb.OnTimeBehavior_FIRE_IF_NONEMPTY)
-               check("WindowingStrategy.OutputTime", ws.GetOutputTime(), 
pipepb.OutputTime_END_OF_WINDOW)
-               // Non nil triggers should fail.
-               if ws.GetTrigger().GetDefault() == nil {
-                       check("WindowingStrategy.Trigger", ws.GetTrigger(), 
&pipepb.Trigger_Default{})
+               if !bypassedWindowingStrategies[wsID] {
+                       check("WindowingStrategy.OnTimeBehavior", 
ws.GetOnTimeBehavior(), pipepb.OnTimeBehavior_FIRE_IF_NONEMPTY)
+                       check("WindowingStrategy.OutputTime", 
ws.GetOutputTime(), pipepb.OutputTime_END_OF_WINDOW)
+                       // Non nil triggers should fail.
+                       if ws.GetTrigger().GetDefault() == nil {
+                               check("WindowingStrategy.Trigger", 
ws.GetTrigger(), &pipepb.Trigger_Default{})
+                       }
                }
        }
        if len(errs) > 0 {
                jErr := &joinError{errs: errs}
                slog.Error("unable to run job", slog.String("cause", 
"unimplemented features"), slog.String("jobname", req.GetJobName()), 
slog.String("errors", jErr.Error()))
-               err := fmt.Errorf("found %v uses of features unimplemented in 
prism in job %v: %v", len(errs), req.GetJobName(), jErr)
+               err := fmt.Errorf("found %v uses of features unimplemented in 
prism in job %v:\n%v", len(errs), req.GetJobName(), jErr)
                job.Failed(err)
                return nil, err
        }
diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go 
b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go
index 8769a05d38f..96c5f5549b0 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go
@@ -16,12 +16,15 @@
 package internal
 
 import (
+       "fmt"
        "sort"
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
        pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
        "golang.org/x/exp/maps"
        "golang.org/x/exp/slog"
+       "google.golang.org/protobuf/encoding/prototext"
 )
 
 // transformPreparer is an interface for handling different urns in the 
preprocessor
@@ -138,11 +141,253 @@ func (p *preprocessor) preProcessGraph(comps 
*pipepb.Components) []*stage {
        topological := pipelinex.TopologicalSort(ts, keptLeaves)
        slog.Debug("topological transform ordering", topological)
 
+       // Basic Fusion Behavior
+       //
+       // Fusion is the practice of executing associated DoFns in the same 
stage.
+       // This often leads to more efficient processing, since costly 
encode/decode or
+       // serialize/deserialize operations can be elided. In Beam, any 
PCollection can
+       // in principle serve as a place for serializing and deserializing 
elements.
+       //
+       // In particular, Fusion is a stage for optimizing pipeline execution, 
and was
+       // described in the FlumeJava paper, in section 4.
+       // 
https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/35650.pdf
+       //
+       // Per the FlumeJava paper, there are two primary opportunities for 
Fusion,
+       // Producer+Consumer fusion and Sibling fusion.
+       //
+       // Producer+Consumer fusion is when the producer of a PCollection and 
the consumers of
+       // that PCollection are combined into a single stage. Sibling fusion is 
when two consumers
+       // of the same pcollection are fused into the same step. These 
processes can continue until
+       // graph structure or specific transforms dictate that fusion may not 
proceed further.
+       //
+       // Examples of fusion breaks include GroupByKeys, or requiring side 
inputs to complete
+       // processing for downstream processing, since the producer and 
consumer of side inputs
+       // cannot be in the same fused stage.
+       //
+       // Additionally, at this phase, we can consider different optimizations 
for execution.
+       // For example "Flatten unzipping". In practice, there's no requirement 
for any stages
+       // to have an explicit "Flatten" present in the graph. A flatten can be 
"unzipped",
+       // duplicating the consuming transforms after the flatten, until a 
subsequent fusion break.
+       // This enables additional parallelism by allowing sources to operate 
in their own independent
+       // stages. Beam supports this naturally with the separation of work 
into independent
+       // bundles for execution.
+
+       return defaultFusion(topological, comps)
+}
+
+// defaultFusion is the base strategy for prism, that doesn't seek to optimize 
execution
+// with fused stages. Input is the set of leaf nodes we're going to execute, 
topologically
+// sorted, and the pipeline components.
+//
+// Default fusion behavior: Don't. Prism is intended to test all of Beam, 
which often
+// means for testing purposes, to execute pipelines without optimization.
+//
+// Special Exception to unfused Go SDK pipelines.
+//
+// If a transform, after a GBK step, has a single input with a KV<K, Iter<X>> 
coder
+// and a single output O with a KV<K, Iter<Y>> coder, then it must be 
fused with
+// the consumers of O.
+func defaultFusion(topological []string, comps *pipepb.Components) []*stage {
        var stages []*stage
+
+       // TODO figure out a better place to source the PCol Parents/Consumers 
analysis
+       // so we don't keep repeating it.
+
+       pcolParents, pcolConsumers := computPColFacts(topological, comps)
+
+       // Explicitly list the pcollectionID we want to fuse along.
+       fuseWithConsumers := map[string]string{}
+       for _, tid := range topological {
+               t := comps.GetTransforms()[tid]
+
+               // See if this transform has a single input and output
+               if len(t.GetInputs()) != 1 || len(t.GetOutputs()) != 1 {
+                       continue
+               }
+               inputID := getOnlyValue(t.GetInputs())
+               outputID := getOnlyValue(t.GetOutputs())
+
+               parentLink := pcolParents[inputID]
+
+               parent := comps.GetTransforms()[parentLink.transform]
+
+               // Check if the input source is a GBK
+               if parent.GetSpec().GetUrn() != urns.TransformGBK {
+                       continue
+               }
+
+               // Check if the coder is a KV<K, Iter<?>>
+               iCID := comps.GetPcollections()[inputID].GetCoderId()
+               oCID := comps.GetPcollections()[outputID].GetCoderId()
+
+               if checkForExpandCoderPattern(iCID, oCID, comps) {
+                       fuseWithConsumers[tid] = outputID
+               }
+       }
+
+       // Since we iterate in topological order, we're guaranteed to process 
producers before consumers.
+       consumed := map[string]bool{} // Checks if we've already handled a 
transform due to fusion.
        for _, tid := range topological {
-               stages = append(stages, &stage{
+               if consumed[tid] {
+                       continue
+               }
+               stg := &stage{
                        transforms: []string{tid},
-               })
+               }
+               // TODO validate that fused stages have the same environment.
+               stg.envID = comps.GetTransforms()[tid].EnvironmentId
+
+               stages = append(stages, stg)
+
+               pcolID, ok := fuseWithConsumers[tid]
+               if !ok {
+                       continue
+               }
+               cs := pcolConsumers[pcolID]
+
+               for _, c := range cs {
+                       stg.transforms = append(stg.transforms, c.transform)
+                       consumed[c.transform] = true
+               }
+       }
+
+       for _, stg := range stages {
+               prepareStage(stg, comps, pcolConsumers)
        }
        return stages
 }
+
+// computPColFacts computes a map of PCollectionIDs to their parent 
transforms, and a map of
+// PCollectionIDs to their consuming transforms.
+func computPColFacts(topological []string, comps *pipepb.Components) 
(map[string]link, map[string][]link) {
+       pcolParents := map[string]link{}
+       pcolConsumers := map[string][]link{}
+
+       // Use the topological ids so each PCollection only has a single
+       // parent. We've already pruned out composites at this stage.
+       for _, tID := range topological {
+               t := comps.GetTransforms()[tID]
+               for local, global := range t.GetOutputs() {
+                       pcolParents[global] = link{transform: tID, local: 
local, global: global}
+               }
+               for local, global := range t.GetInputs() {
+                       pcolConsumers[global] = append(pcolConsumers[global], 
link{transform: tID, local: local, global: global})
+               }
+       }
+
+       return pcolParents, pcolConsumers
+}
+
+// We need to see that both coders have this pattern: KV<K, Iter<?>>
+func checkForExpandCoderPattern(in, out string, comps *pipepb.Components) bool 
{
+       isKV := func(id string) bool {
+               return comps.GetCoders()[id].GetSpec().GetUrn() == urns.CoderKV
+       }
+       getComp := func(id string, i int) string {
+               return comps.GetCoders()[id].GetComponentCoderIds()[i]
+       }
+       isIter := func(id string) bool {
+               return comps.GetCoders()[id].GetSpec().GetUrn() == 
urns.CoderIterable
+       }
+       if !isKV(in) || !isKV(out) {
+               return false
+       }
+       // Are the keys identical?
+       if getComp(in, 0) != getComp(out, 0) {
+               return false
+       }
+       // Are both values iterables?
+       if isIter(getComp(in, 1)) && isIter(getComp(out, 1)) {
+               // If so we have the ExpandCoderPattern from the Go SDK. Hurray!
+               return true
+       }
+       return false
+}
+
+// prepareStage does the final pre-processing step for stages:
+//
+// 1. Determining the single parallel input (may be 0 for impulse stages).
+// 2. Determining all outputs to the stages.
+// 3. Determining all side inputs.
+// 4. Validating that no side input is fed by an internal PCollection.
+// 5. Checking that all transforms are in the same environment or are environment 
agnostic. (TODO for xlang)
+// 6. Validating that only the primary input consuming transforms are stateful. 
(Might be able to relax this)
+//
+// Those final steps are necessary to validate that the stage doesn't have any 
issues, WRT retries or similar.
+//
+// A PCollection produced by a transform in this stage is in the output set if 
it's consumed by a transform outside of the stage.
+//
+// Finally, it takes this information and caches it in the stage for simpler 
descriptor construction downstream.
+//
+// Note, this is very similar to the work done WRT composites in 
pipelinex.Normalize.
+func prepareStage(stg *stage, comps *pipepb.Components, pipelineConsumers 
map[string][]link) {
+       // Collect all PCollections involved in this stage.
+       pcolParents, pcolConsumers := computPColFacts(stg.transforms, comps)
+
+       transformSet := map[string]bool{}
+       for _, tid := range stg.transforms {
+               transformSet[tid] = true
+       }
+
+       // Now we can see which consumers (inputs) aren't covered by the 
parents (outputs).
+       mainInputs := map[string]string{}
+       var sideInputs []link
+       inputs := map[string]bool{}
+       for pid, plinks := range pcolConsumers {
+               // Check if this PCollection is generated in this bundle.
+               if _, ok := pcolParents[pid]; ok {
+                       // It is, so we will ignore for now.
+                       continue
+               }
+               // Add this collection to our input set.
+               inputs[pid] = true
+               for _, link := range plinks {
+                       t := comps.GetTransforms()[link.transform]
+                       sis, _ := getSideInputs(t)
+                       if _, ok := sis[link.local]; ok {
+                               sideInputs = append(sideInputs, link)
+                       } else {
+                               mainInputs[link.global] = link.global
+                       }
+               }
+       }
+       outputs := map[string]link{}
+       var internal []string
+       // Look at all PCollections produced in this stage.
+       for pid, link := range pcolParents {
+               // Look at all consumers of this PCollection in the pipeline
+               isInternal := true
+               for _, l := range pipelineConsumers[pid] {
+                       // If the consuming transform isn't in the stage, it's 
an output.
+                       if !transformSet[l.transform] {
+                               isInternal = false
+                               outputs[pid] = link
+                       }
+               }
+               // It's consumed as an output, we already ensure the coder's in 
the set.
+               if isInternal {
+                       internal = append(internal, pid)
+               }
+       }
+
+       stg.internalCols = internal
+       stg.outputs = maps.Values(outputs)
+       stg.sideInputs = sideInputs
+
+       defer func() {
+               if e := recover(); e != nil {
+                       panic(fmt.Sprintf("stage %+v:\n%v\n\n%v", stg, e, 
prototext.Format(comps)))
+               }
+       }()
+
+       // Impulses won't have any inputs.
+       if l := len(mainInputs); l == 1 {
+               stg.primaryInput = getOnlyValue(mainInputs)
+       } else if l > 1 {
+               // Quick check that this is a lone flatten node, which is 
handled runner side anyway
+               // and only sent SDK side as part of a fused stage.
+               if !(len(stg.transforms) == 1 && 
comps.GetTransforms()[stg.transforms[0]].GetSpec().GetUrn() == 
urns.TransformFlatten) {
+                       panic("expected flatten node, but wasn't")
+               }
+       }
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go
index add69a7c767..ba39d024e71 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go
@@ -20,6 +20,7 @@ import (
 
        pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
        "github.com/google/go-cmp/cmp"
+       "github.com/google/go-cmp/cmp/cmpopts"
        "google.golang.org/protobuf/testing/protocmp"
 )
 
@@ -73,7 +74,10 @@ func Test_preprocessor_preProcessGraph(t *testing.T) {
                                Environments:        
map[string]*pipepb.Environment{},
                        },
 
-                       wantStages: []*stage{{transforms: 
[]string{"e1_early"}}, {transforms: []string{"e1_late"}}},
+                       wantStages: []*stage{
+                               {transforms: []string{"e1_early"}, envID: 
"env1",
+                                       outputs: []link{{transform: "e1_early", 
local: "i0", global: "pcol1"}}},
+                               {transforms: []string{"e1_late"}, envID: 
"env1", primaryInput: "pcol1"}},
                        wantComponents: &pipepb.Components{
                                Transforms: map[string]*pipepb.PTransform{
                                        // Original is always kept
@@ -124,11 +128,11 @@ func Test_preprocessor_preProcessGraph(t *testing.T) {
                        pre := 
newPreprocessor([]transformPreparer{&testPreparer{}})
 
                        gotStages := pre.preProcessGraph(test.input)
-                       if diff := cmp.Diff(test.wantStages, gotStages, 
cmp.AllowUnexported(stage{})); diff != "" {
+                       if diff := cmp.Diff(test.wantStages, gotStages, 
cmp.AllowUnexported(stage{}, link{}), cmpopts.EquateEmpty()); diff != "" {
                                t.Errorf("preProcessGraph(%q) stages diff 
(-want,+got)\n%v", test.name, diff)
                        }
 
-                       if diff := cmp.Diff(test.input, test.wantComponents, 
protocmp.Transform()); diff != "" {
+                       if diff := cmp.Diff(test.wantComponents, test.input, 
protocmp.Transform()); diff != "" {
                                t.Errorf("preProcessGraph(%q) components diff 
(-want,+got)\n%v", test.name, diff)
                        }
                })
diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go 
b/sdks/go/pkg/beam/runners/prism/internal/stage.go
index 44f9c1e9d28..e6fe28714b7 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/stage.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go
@@ -36,20 +36,31 @@ import (
        "google.golang.org/protobuf/proto"
 )
 
-// stage represents a fused subgraph.
+// link represents the tuple of a transform, the local id, and the global id 
for
+// that transform's respective input or output. Which it is, is context 
dependent,
+// and not knowable from just the link itself, but can be verified against the 
transform proto.
+type link struct {
+       transform, local, global string
+}
+
+// stage represents a fused subgraph executed in a single environment.
 //
-// TODO: do we guarantee that they are all
-// the same environment at this point, or
-// should that be handled later?
+// TODO: Consider ignoring environment boundaries and making fusion
+// only consider necessary materialization breaks. The data protocol
+// should in principle be able to connect two SDK environments directly
+// instead of going through the runner at all, which would be a small
+// efficiency gain, in runner memory use.
 type stage struct {
-       ID         string
-       transforms []string
+       ID           string
+       transforms   []string
+       primaryInput string   // PCollection used as the parallel input.
+       outputs      []link   // PCollections that must escape this stage.
+       sideInputs   []link   // Non-parallel input PCollections and their 
consumers
+       internalCols []string // PCollections that don't escape. Used for precise 
coder sending.
+       envID        string
 
-       envID            string
        exe              transformExecuter
-       outputCount      int
        inputTransformID string
-       mainInputPCol    string
        inputInfo        engine.PColInfo
        desc             *fnpb.ProcessBundleDescriptor
        sides            []string
@@ -60,16 +71,19 @@ type stage struct {
 }
 
 func (s *stage) Execute(j *jobservices.Job, wk *worker.W, comps 
*pipepb.Components, em *engine.ElementManager, rb engine.RunBundle) {
-       tid := s.transforms[0]
-       slog.Debug("Execute: starting bundle", "bundle", rb, slog.String("tid", 
tid))
+       slog.Debug("Execute: starting bundle", "bundle", rb)
 
        var b *worker.B
        inputData := em.InputForBundle(rb, s.inputInfo)
        var dataReady <-chan struct{}
        switch s.envID {
        case "": // Runner Transforms
+               if len(s.transforms) != 1 {
+                       panic(fmt.Sprintf("unexpected number of runner 
transforms, want 1: %+v", s))
+               }
+               tid := s.transforms[0]
                // Runner transforms are processed immeadiately.
-               b = s.exe.ExecuteTransform(tid, comps.GetTransforms()[tid], 
comps, rb.Watermark, inputData)
+               b = s.exe.ExecuteTransform(s.ID, tid, 
comps.GetTransforms()[tid], comps, rb.Watermark, inputData)
                b.InstID = rb.BundleID
                slog.Debug("Execute: runner transform", "bundle", rb, 
slog.String("tid", tid))
 
@@ -90,7 +104,7 @@ func (s *stage) Execute(j *jobservices.Job, wk *worker.W, 
comps *pipepb.Componen
                        InputData: inputData,
 
                        SinkToPCollection: s.SinkToPCollection,
-                       OutputCount:       s.outputCount,
+                       OutputCount:       len(s.outputs),
                }
                b.Init()
 
@@ -207,7 +221,7 @@ progress:
                }
        }
        if l := len(residualData); l > 0 {
-               slog.Debug("returned empty residual application", "bundle", rb, 
slog.Int("numResiduals", l), slog.String("pcollection", s.mainInputPCol))
+               slog.Debug("returned empty residual application", "bundle", rb, 
slog.Int("numResiduals", l), slog.String("pcollection", s.primaryInput))
        }
        em.PersistBundle(rb, s.OutputsToCoders, b.OutputData, s.inputInfo, 
residualData, minOutputWatermark)
        b.OutputData = engine.TentativeData{} // Clear the data.
@@ -217,6 +231,7 @@ func getSideInputs(t *pipepb.PTransform) 
(map[string]*pipepb.SideInput, error) {
        if t.GetSpec().GetUrn() != urns.TransformParDo {
                return nil, nil
        }
+       // TODO, memoize this, so we don't need to repeatedly unmarshal.
        pardo := &pipepb.ParDoPayload{}
        if err := 
(proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != 
nil {
                return nil, fmt.Errorf("unable to decode ParDoPayload")
@@ -238,99 +253,103 @@ func portFor(wInCid string, wk *worker.W) []byte {
        return sourcePortBytes
 }
 
-func buildStage(s *stage, tid string, t *pipepb.PTransform, comps 
*pipepb.Components, wk *worker.W) {
-       s.inputTransformID = tid + "_source"
+// buildDescriptor constructs a ProcessBundleDescriptor for bundles of this 
stage.
+//
+// Requirements:
+// * The set of inputs to the stage includes only one parallel input.
+// * The side input pcollections are fully qualified with global pcollection 
ID, ingesting transform, and local inputID.
+// * The outputs are fully qualified with global PCollectionID, producing 
transform, and local outputID.
+//
+// It assumes that the side inputs are not sourced from PCollections generated 
by any transform in this stage.
+//
+// Because we need the local ids for routing the sources/sinks information.
+func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W) error 
{
+       // Assume stage has an indicated primary input
 
        coders := map[string]*pipepb.Coder{}
-       transforms := map[string]*pipepb.PTransform{
-               tid: t, // The Transform to Execute!
-       }
+       transforms := map[string]*pipepb.PTransform{}
 
-       sis, err := getSideInputs(t)
-       if err != nil {
-               slog.Error("buildStage: getSide Inputs", err, 
slog.String("transformID", tid))
-               panic(err)
+       for _, tid := range stg.transforms {
+               transforms[tid] = comps.GetTransforms()[tid]
        }
-       var inputInfo engine.PColInfo
-       var sides []string
-       localIdReplacements := map[string]string{}
-       globalIDReplacements := map[string]string{}
-       for local, global := range t.GetInputs() {
-               if _, ok := sis[local]; ok {
-                       col := comps.GetPcollections()[global]
-                       oCID := col.GetCoderId()
-                       nCID := lpUnknownCoders(oCID, coders, comps.GetCoders())
-
-                       sides = append(sides, global)
-                       if oCID != nCID {
-                               // Add a synthetic PCollection set with the new 
coder.
-                               newGlobal := global + "_prismside"
-                               comps.GetPcollections()[newGlobal] = 
&pipepb.PCollection{
-                                       DisplayData:         
col.GetDisplayData(),
-                                       UniqueName:          
col.GetUniqueName(),
-                                       CoderId:             nCID,
-                                       IsBounded:           col.GetIsBounded(),
-                                       WindowingStrategyId: 
col.WindowingStrategyId,
-                               }
-                               localIdReplacements[local] = newGlobal
-                               globalIDReplacements[newGlobal] = global
-                       }
-                       continue
-               }
-               // This id is directly used for the source, but this also copies
-               // coders used by side inputs to the coders map for the bundle, 
so
-               // needs to be run for every ID.
-               wInCid := makeWindowedValueCoder(global, comps, coders)
 
-               // this is the main input
-               transforms[s.inputTransformID] = 
sourceTransform(s.inputTransformID, portFor(wInCid, wk), global)
-               col := comps.GetPcollections()[global]
+       // Start with outputs, since they're simple and uniform.
+       sink2Col := map[string]string{}
+       col2Coders := map[string]engine.PColInfo{}
+       for _, o := range stg.outputs {
+               wOutCid := makeWindowedValueCoder(o.global, comps, coders)
+               sinkID := o.transform + "_" + o.local
+               col := comps.GetPcollections()[o.global]
                ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
                wDec, wEnc := getWindowValueCoders(comps, col, coders)
-               inputInfo = engine.PColInfo{
-                       GlobalID: global,
+               sink2Col[sinkID] = o.global
+               col2Coders[o.global] = engine.PColInfo{
+                       GlobalID: o.global,
                        WDec:     wDec,
                        WEnc:     wEnc,
                        EDec:     ed,
                }
-               // We need to process all inputs to ensure we have all input 
coders, so we must continue.
+               transforms[sinkID] = sinkTransform(sinkID, portFor(wOutCid, 
wk), o.global)
        }
-       // Update side inputs to point to new PCollection with any replaced 
coders.
-       for l, g := range localIdReplacements {
-               t.GetInputs()[l] = g
+
+       // Then let's do Side Inputs, since they are also uniform.
+       var sides []string
+       var prepareSides []func(b *worker.B, watermark mtime.Time)
+       for _, si := range stg.sideInputs {
+               col := comps.GetPcollections()[si.global]
+               oCID := col.GetCoderId()
+               nCID := lpUnknownCoders(oCID, coders, comps.GetCoders())
+
+               sides = append(sides, si.global)
+               if oCID != nCID {
+                       // Add a synthetic PCollection set with the new coder.
+                       newGlobal := si.global + "_prismside"
+                       comps.GetPcollections()[newGlobal] = 
&pipepb.PCollection{
+                               DisplayData:         col.GetDisplayData(),
+                               UniqueName:          col.GetUniqueName(),
+                               CoderId:             nCID,
+                               IsBounded:           col.GetIsBounded(),
+                               WindowingStrategyId: col.WindowingStrategyId,
+                       }
+                       // Update side inputs to point to new PCollection with 
any replaced coders.
+                       transforms[si.transform].GetInputs()[si.local] = 
newGlobal
+               }
+               prepSide, err := handleSideInput(si.transform, si.local, 
si.global, comps, coders, wk)
+               if err != nil {
+                       slog.Error("buildDescriptor: handleSideInputs", err, 
slog.String("transformID", si.transform))
+                       return err
+               }
+               prepareSides = append(prepareSides, prepSide)
        }
 
-       prepareSides, err := handleSideInputs(t, comps, coders, wk, 
globalIDReplacements)
-       if err != nil {
-               slog.Error("buildStage: handleSideInputs", err, 
slog.String("transformID", tid))
-               panic(err)
+       // Finally, the parallel input, which is its own special snowflake, 
that needs a datasource.
+       // This id is directly used for the source, but this also copies
+       // coders used by side inputs to the coders map for the bundle, so
+       // needs to be run for every ID.
+       wInCid := makeWindowedValueCoder(stg.primaryInput, comps, coders)
+
+       col := comps.GetPcollections()[stg.primaryInput]
+       ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+       wDec, wEnc := getWindowValueCoders(comps, col, coders)
+       inputInfo := engine.PColInfo{
+               GlobalID: stg.primaryInput,
+               WDec:     wDec,
+               WEnc:     wEnc,
+               EDec:     ed,
        }
 
-       // TODO: We need a new logical PCollection to represent the source
-       // so we can avoid double counting PCollection metrics later.
-       // But this also means replacing the ID for the input in the bundle.
-       sink2Col := map[string]string{}
-       col2Coders := map[string]engine.PColInfo{}
-       for local, global := range t.GetOutputs() {
-               wOutCid := makeWindowedValueCoder(global, comps, coders)
-               sinkID := tid + "_" + local
-               col := comps.GetPcollections()[global]
-               ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
-               wDec, wEnc := getWindowValueCoders(comps, col, coders)
-               sink2Col[sinkID] = global
-               col2Coders[global] = engine.PColInfo{
-                       GlobalID: global,
-                       WDec:     wDec,
-                       WEnc:     wEnc,
-                       EDec:     ed,
-               }
-               transforms[sinkID] = sinkTransform(sinkID, portFor(wOutCid, 
wk), global)
+       stg.inputTransformID = stg.ID + "_source"
+       transforms[stg.inputTransformID] = 
sourceTransform(stg.inputTransformID, portFor(wInCid, wk), stg.primaryInput)
+
+       // Add coders for internal collections.
+       for _, pid := range stg.internalCols {
+               lpUnknownCoders(comps.GetPcollections()[pid].GetCoderId(), 
coders, comps.GetCoders())
        }
 
        reconcileCoders(coders, comps.GetCoders())
 
        desc := &fnpb.ProcessBundleDescriptor{
-               Id:                  s.ID,
+               Id:                  stg.ID,
                Transforms:          transforms,
                WindowingStrategies: comps.GetWindowingStrategies(),
                Pcollections:        comps.GetPcollections(),
@@ -340,119 +359,103 @@ func buildStage(s *stage, tid string, t 
*pipepb.PTransform, comps *pipepb.Compon
                },
        }
 
-       s.desc = desc
-       s.outputCount = len(t.Outputs)
-       s.prepareSides = prepareSides
-       s.sides = sides
-       s.SinkToPCollection = sink2Col
-       s.OutputsToCoders = col2Coders
-       s.mainInputPCol = inputInfo.GlobalID
-       s.inputInfo = inputInfo
+       stg.desc = desc
+       stg.prepareSides = func(b *worker.B, _ string, watermark mtime.Time) {
+               for _, prep := range prepareSides {
+                       prep(b, watermark)
+               }
+       }
+       stg.sides = sides // List of the global pcollection IDs this stage 
needs to wait on for side inputs.
+       stg.SinkToPCollection = sink2Col
+       stg.OutputsToCoders = col2Coders
+       stg.inputInfo = inputInfo
 
-       wk.Descriptors[s.ID] = s.desc
+       wk.Descriptors[stg.ID] = stg.desc
+       return nil
 }
 
-// handleSideInputs ensures appropriate coders are available to the bundle, 
and prepares a function to stage the data.
-func handleSideInputs(t *pipepb.PTransform, comps *pipepb.Components, coders 
map[string]*pipepb.Coder, wk *worker.W, replacements map[string]string) (func(b 
*worker.B, tid string, watermark mtime.Time), error) {
+// handleSideInput returns a closure that will look up the data for a side 
input appropriate for the given watermark.
+func handleSideInput(tid, local, global string, comps *pipepb.Components, 
coders map[string]*pipepb.Coder, wk *worker.W) (func(b *worker.B, watermark 
mtime.Time), error) {
+       t := comps.GetTransforms()[tid]
        sis, err := getSideInputs(t)
        if err != nil {
                return nil, err
        }
-       var prepSides []func(b *worker.B, tid string, watermark mtime.Time)
-
-       // Get WindowedValue Coders for the transform's input and output 
PCollections.
-       for local, global := range t.GetInputs() {
-               si, ok := sis[local]
-               if !ok {
-                       continue // This is the main input.
-               }
-               // Use the old global ID as the identifier for the data storage
-               // This matches what we do in the rest of the stage layer.
-               if oldGlobal, ok := replacements[global]; ok {
-                       global = oldGlobal
-               }
 
-               // this is a side input
-               switch si.GetAccessPattern().GetUrn() {
-               case urns.SideInputIterable:
-                       slog.Debug("urnSideInputIterable",
-                               slog.String("sourceTransform", 
t.GetUniqueName()),
-                               slog.String("local", local),
-                               slog.String("global", global))
-                       col := comps.GetPcollections()[global]
-                       ed := collectionPullDecoder(col.GetCoderId(), coders, 
comps)
-                       wDec, wEnc := getWindowValueCoders(comps, col, coders)
-                       // May be of zero length, but that's OK. Side inputs 
can be empty.
+       switch si := sis[local]; si.GetAccessPattern().GetUrn() {
+       case urns.SideInputIterable:
+               slog.Debug("urnSideInputIterable",
+                       slog.String("sourceTransform", t.GetUniqueName()),
+                       slog.String("local", local),
+                       slog.String("global", global))
+               col := comps.GetPcollections()[global]
+               ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+               wDec, wEnc := getWindowValueCoders(comps, col, coders)
+               // May be of zero length, but that's OK. Side inputs can be 
empty.
 
-                       global, local := global, local
-                       prepSides = append(prepSides, func(b *worker.B, tid 
string, watermark mtime.Time) {
-                               data := wk.D.GetAllData(global)
+               global, local := global, local
+               return func(b *worker.B, watermark mtime.Time) {
+                       data := wk.D.GetAllData(global)
 
-                               if b.IterableSideInputData == nil {
-                                       b.IterableSideInputData = 
map[string]map[string]map[typex.Window][][]byte{}
-                               }
-                               if _, ok := b.IterableSideInputData[tid]; !ok {
-                                       b.IterableSideInputData[tid] = 
map[string]map[typex.Window][][]byte{}
-                               }
-                               b.IterableSideInputData[tid][local] = 
collateByWindows(data, watermark, wDec, wEnc,
-                                       func(r io.Reader) [][]byte {
-                                               return [][]byte{ed(r)}
-                                       }, func(a, b [][]byte) [][]byte {
-                                               return append(a, b...)
-                                       })
-                       })
-
-               case urns.SideInputMultiMap:
-                       slog.Debug("urnSideInputMultiMap",
-                               slog.String("sourceTransform", 
t.GetUniqueName()),
-                               slog.String("local", local),
-                               slog.String("global", global))
-                       col := comps.GetPcollections()[global]
-
-                       kvc := comps.GetCoders()[col.GetCoderId()]
-                       if kvc.GetSpec().GetUrn() != urns.CoderKV {
-                               return nil, fmt.Errorf("multimap side inputs 
needs KV coder, got %v", kvc.GetSpec().GetUrn())
+                       if b.IterableSideInputData == nil {
+                               b.IterableSideInputData = 
map[string]map[string]map[typex.Window][][]byte{}
                        }
+                       if _, ok := b.IterableSideInputData[tid]; !ok {
+                               b.IterableSideInputData[tid] = 
map[string]map[typex.Window][][]byte{}
+                       }
+                       b.IterableSideInputData[tid][local] = 
collateByWindows(data, watermark, wDec, wEnc,
+                               func(r io.Reader) [][]byte {
+                                       return [][]byte{ed(r)}
+                               }, func(a, b [][]byte) [][]byte {
+                                       return append(a, b...)
+                               })
+               }, nil
+
+       case urns.SideInputMultiMap:
+               slog.Debug("urnSideInputMultiMap",
+                       slog.String("sourceTransform", t.GetUniqueName()),
+                       slog.String("local", local),
+                       slog.String("global", global))
+               col := comps.GetPcollections()[global]
 
-                       kd := 
collectionPullDecoder(kvc.GetComponentCoderIds()[0], coders, comps)
-                       vd := 
collectionPullDecoder(kvc.GetComponentCoderIds()[1], coders, comps)
-                       wDec, wEnc := getWindowValueCoders(comps, col, coders)
-
-                       global, local := global, local
-                       prepSides = append(prepSides, func(b *worker.B, tid 
string, watermark mtime.Time) {
-                               // May be of zero length, but that's OK. Side 
inputs can be empty.
-                               data := wk.D.GetAllData(global)
-                               if b.MultiMapSideInputData == nil {
-                                       b.MultiMapSideInputData = 
map[string]map[string]map[typex.Window]map[string][][]byte{}
-                               }
-                               if _, ok := b.MultiMapSideInputData[tid]; !ok {
-                                       b.MultiMapSideInputData[tid] = 
map[string]map[typex.Window]map[string][][]byte{}
-                               }
-                               b.MultiMapSideInputData[tid][local] = 
collateByWindows(data, watermark, wDec, wEnc,
-                                       func(r io.Reader) map[string][][]byte {
-                                               kb := kd(r)
-                                               return map[string][][]byte{
-                                                       string(kb): {vd(r)},
-                                               }
-                                       }, func(a, b map[string][][]byte) 
map[string][][]byte {
-                                               if len(a) == 0 {
-                                                       return b
-                                               }
-                                               for k, vs := range b {
-                                                       a[k] = append(a[k], 
vs...)
-                                               }
-                                               return a
-                                       })
-                       })
-               default:
-                       return nil, fmt.Errorf("local input %v (global %v) uses 
accesspattern %v", local, global, si.GetAccessPattern().GetUrn())
+               kvc := comps.GetCoders()[col.GetCoderId()]
+               if kvc.GetSpec().GetUrn() != urns.CoderKV {
+                       return nil, fmt.Errorf("multimap side inputs needs KV 
coder, got %v", kvc.GetSpec().GetUrn())
                }
+
+               kd := collectionPullDecoder(kvc.GetComponentCoderIds()[0], 
coders, comps)
+               vd := collectionPullDecoder(kvc.GetComponentCoderIds()[1], 
coders, comps)
+               wDec, wEnc := getWindowValueCoders(comps, col, coders)
+
+               global, local := global, local
+               return func(b *worker.B, watermark mtime.Time) {
+                       // May be of zero length, but that's OK. Side inputs 
can be empty.
+                       data := wk.D.GetAllData(global)
+                       if b.MultiMapSideInputData == nil {
+                               b.MultiMapSideInputData = 
map[string]map[string]map[typex.Window]map[string][][]byte{}
+                       }
+                       if _, ok := b.MultiMapSideInputData[tid]; !ok {
+                               b.MultiMapSideInputData[tid] = 
map[string]map[typex.Window]map[string][][]byte{}
+                       }
+                       b.MultiMapSideInputData[tid][local] = 
collateByWindows(data, watermark, wDec, wEnc,
+                               func(r io.Reader) map[string][][]byte {
+                                       kb := kd(r)
+                                       return map[string][][]byte{
+                                               string(kb): {vd(r)},
+                                       }
+                               }, func(a, b map[string][][]byte) 
map[string][][]byte {
+                                       if len(a) == 0 {
+                                               return b
+                                       }
+                                       for k, vs := range b {
+                                               a[k] = append(a[k], vs...)
+                                       }
+                                       return a
+                               })
+               }, nil
+       default:
+               return nil, fmt.Errorf("local input %v (global %v) uses 
accesspattern %v", local, global, si.GetAccessPattern().GetUrn())
        }
-       return func(b *worker.B, tid string, watermark mtime.Time) {
-               for _, prep := range prepSides {
-                       prep(b, tid, watermark)
-               }
-       }, nil
 }
 
 func sourceTransform(parentID string, sourcePortBytes []byte, outPID string) 
*pipepb.PTransform {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go
index f5e8ba12a55..334d74fcae1 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go
@@ -32,6 +32,7 @@ import (
 // Test DoFns are registered in the test file, to allow them to be pruned
 // by the compiler outside of test use.
 func init() {
+       register.Function2x0(dofnEmpty)
        register.Function2x0(dofn1)
        register.Function2x0(dofn1kv)
        register.Function3x0(dofn1x2)
@@ -49,6 +50,8 @@ func init() {
        register.Function2x0(dofnKV2)
        register.Function3x0(dofnGBK)
        register.Function3x0(dofnGBK2)
+       register.Function3x0(dofnGBKKV)
+       register.Emitter2[string, int64]()
        register.DoFn3x0[beam.Window, int64, func(int64)]((*int64Check)(nil))
        register.DoFn2x0[string, func(string)]((*stringCheck)(nil))
        register.Function2x0(dofnKV3)
@@ -64,6 +67,9 @@ func init() {
        register.Emitter2[int64, int64]()
 }
 
+func dofnEmpty(imp []byte, emit func(int64)) {
+}
+
 func dofn1(imp []byte, emit func(int64)) {
        emit(1)
        emit(2)
@@ -237,6 +243,14 @@ func dofnGBK2(k int64, vs func(*string) bool, emit 
func(string)) {
        emit(sum)
 }
 
+func dofnGBKKV(k string, vs func(*int64) bool, emit func(string, int64)) {
+       var v, sum int64
+       for vs(&v) {
+               sum += v
+       }
+       emit(k, sum)
+}
+
 type testRow struct {
        A string
        B int64
diff --git a/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go
index 8746507a9c0..f738a299cfd 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go
@@ -43,9 +43,6 @@ func TestUnimplemented(t *testing.T) {
        }{
                // These tests don't terminate, so can't be run.
                // {pipeline: primitives.Drain}, // Can't test drain 
automatically yet.
-               // {pipeline: primitives.Checkpoints},  // Doesn't self 
terminate?
-               // {pipeline: primitives.Flatten}, // Times out, should be 
quick.
-               // {pipeline: primitives.FlattenDup}, // Times out, should be 
quick.
 
                {pipeline: primitives.TestStreamBoolSequence},
                {pipeline: primitives.TestStreamByteSliceSequence},
@@ -72,10 +69,6 @@ func TestUnimplemented(t *testing.T) {
                {pipeline: primitives.TriggerOrFinally},
                {pipeline: primitives.TriggerRepeat},
 
-               // Reshuffle (Due to missing windowing strategy features)
-               {pipeline: primitives.Reshuffle},
-               {pipeline: primitives.ReshuffleKV},
-
                // State API
                {pipeline: primitives.BagStateParDo},
                {pipeline: primitives.BagStateParDoClear},
@@ -102,3 +95,33 @@ func TestUnimplemented(t *testing.T) {
                })
        }
 }
+
+// TODO move these to a more appropriate location.
+// Mostly placed here to have structural parity with the above test
+// and make it easy to move them to a "it works" expectation.
+func TestImplemented(t *testing.T) {
+       initRunner(t)
+
+       tests := []struct {
+               pipeline func(s beam.Scope)
+       }{
+               {pipeline: primitives.Reshuffle},
+               {pipeline: primitives.Flatten},
+               {pipeline: primitives.FlattenDup},
+               {pipeline: primitives.Checkpoints},
+
+               {pipeline: primitives.CoGBK},
+               {pipeline: primitives.ReshuffleKV},
+       }
+
+       for _, test := range tests {
+               t.Run(intTestName(test.pipeline), func(t *testing.T) {
+                       p, s := beam.NewPipelineWithRoot()
+                       test.pipeline(s)
+                       _, err := executeWithT(context.Background(), t, p)
+                       if err != nil {
+                               t.Fatalf("pipeline failed, but feature should 
be implemented in Prism: %v", err)
+                       }
+               })
+       }
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go 
b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
index 7a5fee21fc7..9fc2c1a923c 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
@@ -57,6 +57,7 @@ var (
        // SDK transforms.
        TransformParDo                = ptUrn(pipepb.StandardPTransforms_PAR_DO)
        TransformCombinePerKey        = 
ctUrn(pipepb.StandardPTransforms_COMBINE_PER_KEY)
+       TransformReshuffle            = 
ctUrn(pipepb.StandardPTransforms_RESHUFFLE)
        TransformPreCombine           = 
cmbtUrn(pipepb.StandardPTransforms_COMBINE_PER_KEY_PRECOMBINE)
        TransformMerge                = 
cmbtUrn(pipepb.StandardPTransforms_COMBINE_PER_KEY_MERGE_ACCUMULATORS)
        TransformExtract              = 
cmbtUrn(pipepb.StandardPTransforms_COMBINE_PER_KEY_EXTRACT_OUTPUTS)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go 
b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
index 80bdadc5162..eefab54a54c 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -95,7 +95,7 @@ func New(id string) *W {
 
                D: &DataService{},
        }
-       slog.Info("Serving Worker components", slog.String("endpoint", 
wk.Endpoint()))
+       slog.Debug("Serving Worker components", slog.String("endpoint", 
wk.Endpoint()))
        fnpb.RegisterBeamFnControlServer(wk.server, wk)
        fnpb.RegisterBeamFnDataServer(wk.server, wk)
        fnpb.RegisterBeamFnLoggingServer(wk.server, wk)
diff --git a/sdks/go/pkg/beam/runners/universal/extworker/extworker.go 
b/sdks/go/pkg/beam/runners/universal/extworker/extworker.go
index 6dab9ebbfb0..a7fc308d219 100644
--- a/sdks/go/pkg/beam/runners/universal/extworker/extworker.go
+++ b/sdks/go/pkg/beam/runners/universal/extworker/extworker.go
@@ -63,7 +63,7 @@ type Loopback struct {
 
 // StartWorker initializes a new worker harness, implementing 
BeamFnExternalWorkerPoolServer.StartWorker.
 func (s *Loopback) StartWorker(ctx context.Context, req 
*fnpb.StartWorkerRequest) (*fnpb.StartWorkerResponse, error) {
-       log.Infof(ctx, "starting worker %v", req.GetWorkerId())
+       log.Debugf(ctx, "starting worker %v", req.GetWorkerId())
        s.mu.Lock()
        defer s.mu.Unlock()
        if s.workers == nil {
@@ -136,7 +136,7 @@ func (s *Loopback) StopWorker(ctx context.Context, req 
*fnpb.StopWorkerRequest)
 func (s *Loopback) Stop(ctx context.Context) error {
        s.mu.Lock()
 
-       log.Infof(ctx, "stopping Loopback, and %d workers", len(s.workers))
+       log.Debugf(ctx, "stopping Loopback, and %d workers", len(s.workers))
        s.workers = nil
        s.rootCancel()
 
diff --git a/sdks/go/pkg/beam/runners/universal/runnerlib/job.go 
b/sdks/go/pkg/beam/runners/universal/runnerlib/job.go
index 5752b33892b..8cbb274e184 100644
--- a/sdks/go/pkg/beam/runners/universal/runnerlib/job.go
+++ b/sdks/go/pkg/beam/runners/universal/runnerlib/job.go
@@ -76,7 +76,7 @@ func Prepare(ctx context.Context, client 
jobpb.JobServiceClient, p *pipepb.Pipel
        }
        resp, err := client.Prepare(ctx, req)
        if err != nil {
-               return "", "", "", errors.Wrap(err, "failed to connect to job 
service")
+               return "", "", "", errors.Wrap(err, "job failed to prepare")
        }
        return resp.GetPreparationId(), 
resp.GetArtifactStagingEndpoint().GetUrl(), resp.GetStagingSessionToken(), nil
 }

Reply via email to