[ 
https://issues.apache.org/jira/browse/BEAM-3612?focusedWorklogId=170763&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-170763
 ]

ASF GitHub Bot logged work on BEAM-3612:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 29/Nov/18 18:48
            Start Date: 29/Nov/18 18:48
    Worklog Time Spent: 10m 
      Work Description: aaltay closed pull request #7161: [BEAM-3612] Closurize 
method invocations
URL: https://github.com/apache/beam/pull/7161
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/sdks/go/pkg/beam/core/graph/fn.go 
b/sdks/go/pkg/beam/core/graph/fn.go
index 6ed05a1019c2..29c92f821ed9 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -101,6 +101,17 @@ func NewFn(fn interface{}) (*Fn, error) {
 
        case reflect.Struct:
                methods := make(map[string]*funcx.Fn)
+               if methodsFuncs, ok := reflectx.WrapMethods(fn); ok {
+                       for name, mfn := range methodsFuncs {
+                               f, err := funcx.New(mfn)
+                               if err != nil {
+                                       return nil, fmt.Errorf("method %v 
invalid: %v", name, err)
+                               }
+                               methods[name] = f
+                       }
+                       return &Fn{Recv: fn, methods: methods}, nil
+               }
+               // TODO(lostluck): Consider moving this into the reflectx 
package.
                for i := 0; i < val.Type().NumMethod(); i++ {
                        m := val.Type().Method(i)
                        if m.PkgPath != "" {
diff --git a/sdks/go/pkg/beam/core/runtime/exec/fn_test.go 
b/sdks/go/pkg/beam/core/runtime/exec/fn_test.go
index 2b1b96ea96a2..2e8a47cace38 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/fn_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/fn_test.go
@@ -505,21 +505,33 @@ func BenchmarkMethodCalls(b *testing.B) {
 
        indirectFunc := reflect.ValueOf(WhatsB).Interface().(func(int) int)
 
-       nrF := fV.Method(0)
-       nrFi := nrF.Interface().(func(int) int)
-       rxnrF := reflectx.MakeFunc(nrFi)
-       rx0x1nrF := reflectx.ToFunc1x1(rxnrF)
-       shimnrF := funcMakerInt(nrFi)             // as if this shim were 
registered
-       shim0x1nrF := reflectx.ToFunc1x1(shimnrF) // would be MakeFunc0x1 if 
registered
-
-       wrF := fV.Type().Method(0).Func
-       wrFi := wrF.Interface().(func(*Foo, int) int)
-
-       rxF := reflectx.MakeFunc(wrFi)
-       rx1x1F := reflectx.ToFunc2x1(rxF)
-       shimF := funcMakerFooRInt(wrFi)       // as if this shim were registered
-       shim1x1F := reflectx.ToFunc2x1(shimF) // would be MakeFunc1x1 if 
registered
-
+       // Implicit Receivers
+       impRF := fV.Method(0)
+       impRFi := impRF.Interface().(func(int) int)
+       impRxF := reflectx.MakeFunc(impRFi)
+       impRx1x1F := reflectx.ToFunc1x1(impRxF)
+       impRShimF := funcMakerInt(impRFi)             // as if this shim were 
registered
+       impRShim1x1F := reflectx.ToFunc1x1(impRShimF) // would be MakeFunc1x1 
if registered
+
+       // Explicit Receivers
+       expRF := fV.Type().Method(0).Func
+       expRFi := expRF.Interface().(func(*Foo, int) int)
+
+       expRxF := reflectx.MakeFunc(expRFi)
+       expRx2c1F := reflectx.ToFunc2x1(expRxF)
+       expRShimF := funcMakerFooRInt(expRFi)         // as if this shim were 
registered
+       expRShim2x1F := reflectx.ToFunc2x1(expRShimF) // would be MakeFunc2x1 
if registered
+
+       // Closured Receivers
+       wrappedWhatsA := func(a int) int { return f.WhatsA(a) }
+       clsrRF := reflect.ValueOf(wrappedWhatsA)
+       clsrRFi := clsrRF.Interface().(func(int) int)
+       clsrRxF := reflectx.MakeFunc(clsrRFi)
+       clsrRx1x1F := reflectx.ToFunc1x1(clsrRxF)
+       clsrRShimF := funcMakerInt(clsrRFi)             // as if this shim were 
registered
+       clsrRShim1x1F := reflectx.ToFunc1x1(clsrRShimF) // would be MakeFunc1x1 
if registered
+
+       // Parameters
        var a int
        var ai interface{} = a
        aV := reflect.ValueOf(a)
@@ -532,39 +544,57 @@ func BenchmarkMethodCalls(b *testing.B) {
                name string
                fn   func()
        }{
-               {"DirectMethod", func() { a = g.WhatsA(a) }}, // Baseline as 
low as we can go.
-               {"DirectFunc", func() { a = WhatsB(a) }},     // For comparison 
purposes
+               {"DirectMethod", func() { a = g.WhatsA(a) }},     // Baseline 
as low as we can go.
+               {"DirectFunc", func() { a = WhatsB(a) }},         // For 
comparison purposes
+               {"IndirectFunc", func() { a = indirectFunc(a) }}, // For 
comparison purposes
+
+               // Implicits
+               {"IndirectImplicit", func() { a = impRFi(a) }},             // 
Measures the indirection through reflect.Value cost.
+               {"TypeAssertedImplicit", func() { ai = impRFi(ai.(int)) }}, // 
Measures the type assertion cost over the above.
+
+               {"ReflectCallImplicit", func() { a = 
impRF.Call([]reflect.Value{reflect.ValueOf(a)})[0].Interface().(int) }},
+               {"ReflectCallImplicit-NoWrap", func() { a = 
impRF.Call([]reflect.Value{aV})[0].Interface().(int) }},
+               {"ReflectCallImplicit-NoReallocSlice", func() { a = 
impRF.Call(rvSlice)[0].Interface().(int) }},
+
+               {"ReflectXCallImplicit", func() { a = 
impRxF.Call([]interface{}{a})[0].(int) }},
+               {"ReflectXCallImplicit-NoReallocSlice", func() { a = 
impRxF.Call(efaceSlice)[0].(int) }},
+               {"ReflectXCall1x1Implicit", func() { a = 
impRx1x1F.Call1x1(a).(int) }}, // Measures the default shimfunc overhead.
+
+               {"ShimedCallImplicit", func() { a = 
impRShimF.Call([]interface{}{a})[0].(int) }},          // What we're currently 
using for invoking methods
+               {"ShimedCallImplicit-NoReallocSlice", func() { a = 
impRShimF.Call(efaceSlice)[0].(int) }}, // Closer to what we're using now.
+               {"ShimedCall1x1Implicit", func() { a = 
impRShim1x1F.Call1x1(a).(int) }},
 
-               {"IndirectFunc", func() { a = indirectFunc(a) }},         // 
For comparison purposes
-               {"IndirectImplicit", func() { a = nrFi(a) }},             // 
Measures the indirection through reflect.Value cost.
-               {"TypeAssertedImplicit", func() { ai = nrFi(ai.(int)) }}, // 
Measures the type assertion cost over the above.
+               // Explicit
+               {"IndirectExplicit", func() { a = expRFi(g, a) }},              
       // Measures the indirection through reflect.Value cost.
+               {"TypeAssertedExplicit", func() { ai = expRFi(gi.(*Foo), 
ai.(int)) }}, // Measures the type assertion cost over the above.
 
-               {"ReflectCallImplicit", func() { a = 
nrF.Call([]reflect.Value{reflect.ValueOf(a)})[0].Interface().(int) }},
-               {"ReflectCallImplicit-NoWrap", func() { a = 
nrF.Call([]reflect.Value{aV})[0].Interface().(int) }},
-               {"ReflectCallImplicit-NoReallocSlice", func() { a = 
nrF.Call(rvSlice)[0].Interface().(int) }},
+               {"ReflectCallExplicit", func() { a = 
expRF.Call([]reflect.Value{reflect.ValueOf(g), 
reflect.ValueOf(a)})[0].Interface().(int) }},
+               {"ReflectCallExplicit-NoWrap", func() { a = 
expRF.Call([]reflect.Value{gV, aV})[0].Interface().(int) }},
+               {"ReflectCallExplicit-NoReallocSlice", func() { a = 
expRF.Call(grvSlice)[0].Interface().(int) }},
 
-               {"ReflectXCallImplicit", func() { a = 
rxnrF.Call([]interface{}{a})[0].(int) }},
-               {"ReflectXCallImplicit-NoReallocSlice", func() { a = 
rxnrF.Call(efaceSlice)[0].(int) }},
-               {"ReflectXCall1x1Implicit", func() { a = 
rx0x1nrF.Call1x1(a).(int) }}, // Measures the default shimfunc overhead.
+               {"ReflectXCallExplicit", func() { a = 
expRxF.Call([]interface{}{g, a})[0].(int) }},
+               {"ReflectXCallExplicit-NoReallocSlice", func() { a = 
expRxF.Call(gEfaceSlice)[0].(int) }},
+               {"ReflectXCall2x1Explicit", func() { a = expRx2c1F.Call2x1(g, 
a).(int) }},
 
-               {"ShimedCallImplicit", func() { a = 
shimnrF.Call([]interface{}{a})[0].(int) }},          // What we're currently 
using for invoking methods
-               {"ShimedCallImplicit-NoReallocSlice", func() { a = 
shimnrF.Call(efaceSlice)[0].(int) }}, // Closer to what we're using now.
-               {"ShimedCall1x1Implicit", func() { a = 
shim0x1nrF.Call1x1(a).(int) }},
+               {"ShimedCallExplicit", func() { a = 
expRShimF.Call([]interface{}{g, a})[0].(int) }},
+               {"ShimedCallExplicit-NoReallocSlice", func() { a = 
expRShimF.Call(gEfaceSlice)[0].(int) }},
+               {"ShimedCall2x1Explicit", func() { a = expRShim2x1F.Call2x1(g, 
a).(int) }},
 
-               {"IndirectExplicit", func() { a = wrFi(g, a) }},                
     // Measures the indirection through reflect.Value cost.
-               {"TypeAssertedExplicit", func() { ai = wrFi(gi.(*Foo), 
ai.(int)) }}, // Measures the type assertion cost over the above.
+               // Closured
+               {"IndirectClosured", func() { a = clsrRFi(a) }},             // 
Measures the indirection through reflect.Value cost.
+               {"TypeAssertedClosured", func() { ai = clsrRFi(ai.(int)) }}, // 
Measures the type assertion cost over the above.
 
-               {"ReflectCallExplicit", func() { a = 
wrF.Call([]reflect.Value{reflect.ValueOf(g), 
reflect.ValueOf(a)})[0].Interface().(int) }},
-               {"ReflectCallExplicit-NoWrap", func() { a = 
wrF.Call([]reflect.Value{gV, aV})[0].Interface().(int) }},
-               {"ReflectCallExplicit-NoReallocSlice", func() { a = 
wrF.Call(grvSlice)[0].Interface().(int) }},
+               {"ReflectCallClosured", func() { a = 
clsrRF.Call([]reflect.Value{reflect.ValueOf(a)})[0].Interface().(int) }},
+               {"ReflectCallClosured-NoWrap", func() { a = 
clsrRF.Call([]reflect.Value{aV})[0].Interface().(int) }},
+               {"ReflectCallClosured-NoReallocSlice", func() { a = 
clsrRF.Call(rvSlice)[0].Interface().(int) }},
 
-               {"ReflectXCallExplicit", func() { a = rxF.Call([]interface{}{g, 
a})[0].(int) }},
-               {"ReflectXCallExplicit-NoReallocSlice", func() { a = 
rxF.Call(gEfaceSlice)[0].(int) }},
-               {"ReflectXCall2x1Explicit", func() { a = rx1x1F.Call2x1(g, 
a).(int) }},
+               {"ReflectXCallClosured", func() { a = 
clsrRxF.Call([]interface{}{a})[0].(int) }},
+               {"ReflectXCallClosured-NoReallocSlice", func() { a = 
clsrRxF.Call(efaceSlice)[0].(int) }},
+               {"ReflectXCall1x1Closured", func() { a = 
clsrRx1x1F.Call1x1(a).(int) }}, // Measures the default shimfunc overhead.
 
-               {"ShimedCallExplicit", func() { a = shimF.Call([]interface{}{g, 
a})[0].(int) }},
-               {"ShimedCallExplicit-NoReallocSlice", func() { a = 
shimF.Call(gEfaceSlice)[0].(int) }},
-               {"ShimedCall2x1Explicit", func() { a = shim1x1F.Call2x1(g, 
a).(int) }},
+               {"ShimedCallClosured", func() { a = 
clsrRShimF.Call([]interface{}{a})[0].(int) }},          // What we're currently 
using for invoking methods
+               {"ShimedCallClosured-NoReallocSlice", func() { a = 
clsrRShimF.Call(efaceSlice)[0].(int) }}, // Closer to what we're using now.
+               {"ShimedCall1x1Closured", func() { a = 
clsrRShim1x1F.Call1x1(a).(int) }},
        }
        for _, test := range tests {
                b.Run(test.name, func(b *testing.B) {
@@ -577,33 +607,51 @@ func BenchmarkMethodCalls(b *testing.B) {
 }
 
 /*
-@lostluck 2018/10/30 on a desktop machine.
-
-BenchmarkMethodCalls/DirectMethod-12                         1000000000        
         2.02 ns/op
-BenchmarkMethodCalls/DirectFunc-12                           2000000000        
         1.81 ns/op
-BenchmarkMethodCalls/IndirectFunc-12                         300000000         
 4.66 ns/op
-BenchmarkMethodCalls/IndirectImplicit-12                       10000000        
       185 ns/op
-BenchmarkMethodCalls/TypeAssertedImplicit-12                   10000000        
       228 ns/op
-BenchmarkMethodCalls/ReflectCallImplicit-12                     3000000        
       479 ns/op
-BenchmarkMethodCalls/ReflectCallImplicit-NoWrap-12              3000000        
       451 ns/op
-BenchmarkMethodCalls/ReflectCallImplicit-NoReallocSlice-12      3000000        
       424 ns/op
-BenchmarkMethodCalls/ReflectXCallImplicit-12                    2000000        
       756 ns/op
-BenchmarkMethodCalls/ReflectXCallImplicit-NoReallocSlice-12            2000000 
       662 ns/op **Default**
-BenchmarkMethodCalls/ReflectXCall1x1Implicit-12                 2000000        
       762 ns/op
-BenchmarkMethodCalls/ShimedCallImplicit-12                      5000000        
       374 ns/op
-BenchmarkMethodCalls/ShimedCallImplicit-NoReallocSlice-12       5000000        
       289 ns/op **With specialized shims**
-BenchmarkMethodCalls/ShimedCall1x1Implicit-12                   5000000        
       249 ns/op **Arity specialized re-work of the invoker**
-
-** Everything below requires an overhaul of structural DoFn invocation code, 
and regeneration of all included shims. **
-BenchmarkMethodCalls/IndirectExplicit-12                      300000000        
         4.81 ns/op
-BenchmarkMethodCalls/TypeAssertedExplicit-12                   50000000        
        35.4 ns/op
-BenchmarkMethodCalls/ReflectCallExplicit-12                     3000000        
       434 ns/op
-BenchmarkMethodCalls/ReflectCallExplicit-NoWrap-12              5000000        
       397 ns/op
-BenchmarkMethodCalls/ReflectCallExplicit-NoReallocSlice-12      5000000        
       390 ns/op
-BenchmarkMethodCalls/ReflectXCallExplicit-12                    2000000        
       755 ns/op
-BenchmarkMethodCalls/ReflectXCallExplicit-NoReallocSlice-12     2000000        
       601 ns/op
-BenchmarkMethodCalls/ReflectXCall2x1Explicit-12                 2000000        
       735 ns/op
-BenchmarkMethodCalls/ShimedCallExplicit-12                     10000000        
       198 ns/op
-BenchmarkMethodCalls/ShimedCallExplicit-NoReallocSlice-12      20000000        
        93.5 ns/op
-BenchmarkMethodCalls/ShimedCall2x1Explicit-12                  20000000        
        68.3 ns/op  **Best we could do**
+@lostluck 2018/11/29 on an Intel(R) Core(TM) i7-7Y75 CPU @ 1.30GHz Pixelbook.
+The actual times will vary per machine, but the relative differences are 
unlikely to change
+between machines.
+
+*** Standard candles ***
+BenchmarkMethodCalls/DirectMethod-4             1000000000               2.00 
ns/op
+BenchmarkMethodCalls/DirectFunc-4               2000000000               1.87 
ns/op
+BenchmarkMethodCalls/IndirectFunc-4             300000000                4.68 
ns/op
+
+*** Implicit receiver variants ***
+BenchmarkMethodCalls/IndirectImplicit-4         10000000               187 
ns/op
+BenchmarkMethodCalls/TypeAssertedImplicit-4     10000000               207 
ns/op
+BenchmarkMethodCalls/ReflectCallImplicit-4       5000000               362 
ns/op
+BenchmarkMethodCalls/ReflectCallImplicit-NoWrap-4                5000000       
        350 ns/op
+BenchmarkMethodCalls/ReflectCallImplicit-NoReallocSlice-4        5000000       
        365 ns/op
+BenchmarkMethodCalls/ReflectXCallImplicit-4                      2000000       
        874 ns/op
+BenchmarkMethodCalls/ReflectXCallImplicit-NoReallocSlice-4               
2000000              1227 ns/op **Default**
+BenchmarkMethodCalls/ReflectXCall1x1Implicit-4                           
1000000              1184 ns/op
+BenchmarkMethodCalls/ShimedCallImplicit-4                                
2000000               647 ns/op
+BenchmarkMethodCalls/ShimedCallImplicit-NoReallocSlice-4                 
3000000               589 ns/op
+BenchmarkMethodCalls/ShimedCall1x1Implicit-4                             
3000000               446 ns/op
+
+*** Explicit receiver variants ***
+BenchmarkMethodCalls/IndirectExplicit-4                                 
200000000                7.64 ns/op
+BenchmarkMethodCalls/TypeAssertedExplicit-4                             
50000000                26.9 ns/op
+BenchmarkMethodCalls/ReflectCallExplicit-4                               
3000000               430 ns/op
+BenchmarkMethodCalls/ReflectCallExplicit-NoWrap-4                        
5000000               394 ns/op
+BenchmarkMethodCalls/ReflectCallExplicit-NoReallocSlice-4                
3000000               375 ns/op
+BenchmarkMethodCalls/ReflectXCallExplicit-4                              
2000000               621 ns/op
+BenchmarkMethodCalls/ReflectXCallExplicit-NoReallocSlice-4               
3000000               552 ns/op
+BenchmarkMethodCalls/ReflectXCall2x1Explicit-4                           
2000000               839 ns/op
+BenchmarkMethodCalls/ShimedCallExplicit-4                                
5000000               208 ns/op
+BenchmarkMethodCalls/ShimedCallExplicit-NoReallocSlice-4                
20000000                70.1 ns/op
+BenchmarkMethodCalls/ShimedCall2x1Explicit-4                            
30000000                48.9 ns/op
+
+*** Closured direct invocation variants ***
+BenchmarkMethodCalls/IndirectClosured-4                                 
300000000                4.93 ns/op
+BenchmarkMethodCalls/TypeAssertedClosured-4                             
100000000               25.7 ns/op
+BenchmarkMethodCalls/ReflectCallClosured-4                               
5000000               318 ns/op
+BenchmarkMethodCalls/ReflectCallClosured-NoWrap-4                        
5000000               269 ns/op
+BenchmarkMethodCalls/ReflectCallClosured-NoReallocSlice-4                
5000000               266 ns/op
+BenchmarkMethodCalls/ReflectXCallClosured-4                              
3000000               440 ns/op
+BenchmarkMethodCalls/ReflectXCallClosured-NoReallocSlice-4               
5000000               377 ns/op
+BenchmarkMethodCalls/ReflectXCall1x1Closured-4                           
3000000               460 ns/op
+BenchmarkMethodCalls/ShimedCallClosured-4                               
20000000               113 ns/op
+BenchmarkMethodCalls/ShimedCallClosured-NoReallocSlice-4                
20000000                61.5 ns/op **With specialized shims**
+BenchmarkMethodCalls/ShimedCall1x1Closured-4                            
30000000                45.5 ns/op **Arity specialized re-work of the invoker**
 */
diff --git a/sdks/go/pkg/beam/core/util/reflectx/call.go 
b/sdks/go/pkg/beam/core/util/reflectx/call.go
index 9a8df8471d18..36861ebd2178 100644
--- a/sdks/go/pkg/beam/core/util/reflectx/call.go
+++ b/sdks/go/pkg/beam/core/util/reflectx/call.go
@@ -55,7 +55,7 @@ func RegisterFunc(t reflect.Type, maker func(interface{}) 
Func) {
 
        key := t.String()
        if _, exists := funcs[key]; exists {
-               log.Warnf(context.Background(), "Func for %v already 
registered. Overwriting.", key)
+               log.Debugf(context.Background(), "Func for %v already 
registered. Overwriting.", key)
        }
        funcs[key] = maker
 }
diff --git a/sdks/go/pkg/beam/core/util/reflectx/structs.go 
b/sdks/go/pkg/beam/core/util/reflectx/structs.go
new file mode 100644
index 000000000000..d5aef1bde8f5
--- /dev/null
+++ b/sdks/go/pkg/beam/core/util/reflectx/structs.go
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reflectx
+
+import (
+       "context"
+       "reflect"
+       "sync"
+
+       "github.com/apache/beam/sdks/go/pkg/beam/log"
+)
+
+var (
+       // structFuncsMu guards every read and write of structFuncs.
+       structFuncsMu sync.Mutex
+
+       // structFuncs maps a struct type's reflect.Type.String() form to the
+       // wrapper registered for it with RegisterStructWrapper.
+       structFuncs = make(map[string]func(interface{}) map[string]Func)
+)
+
+// RegisterStructWrapper takes in the reflect.Type of a structural DoFn, and
+// a wrapping function that will take an instance of that struct type and
+// produce a map of method names to closured Funcs that call the method
+// on the instance of the struct.
+//
+// The goal is to avoid the implicit reflective method invocation penalty
+// that occurs when passing a method through the reflect package.
+//
+// t must be a struct type (not a pointer to one); any other kind is reported
+// via log.Fatalf. Wrappers are keyed by t.String(), and registering a second
+// wrapper under the same key logs a warning and replaces the first.
+// NOTE(review): reflect.Type.String is not guaranteed to be unique (it uses
+// the package name rather than the import path), so same-named types from
+// different packages share a key.
+func RegisterStructWrapper(t reflect.Type, wrapper func(interface{}) map[string]Func) {
+       // Validate before taking the registry lock; nothing below depends on it.
+       if t.Kind() != reflect.Struct {
+               log.Fatalf(context.Background(), "RegisterStructWrapper for %v should be a struct type, but was %v", t, t.Kind())
+       }
+
+       structFuncsMu.Lock()
+       defer structFuncsMu.Unlock()
+
+       key := t.String()
+       // The overwrite check must consult structFuncs itself. The Func
+       // registry (funcs) is a separate map, so checking it would never
+       // detect a replaced struct wrapper.
+       if _, exists := structFuncs[key]; exists {
+               log.Warnf(context.Background(), "StructWrapper for %v already registered. Overwriting.", key)
+       }
+       structFuncs[key] = wrapper
+}
+
+// WrapMethods looks up the struct wrapper registered for fn's type and, when
+// one exists, returns the closured method Funcs it builds for fn together
+// with true. The lookup strips one level of pointer indirection from fn's
+// type. When no wrapper is registered it returns nil, false.
+func WrapMethods(fn interface{}) (map[string]Func, bool) {
+       typ := reflect.TypeOf(fn)
+       return wrapMethodsKeyed(typ, fn)
+}
+
+// wrapMethodsKeyed looks up the struct wrapper registered for type t and, if
+// one exists, applies it to fn, returning the resulting map of method names
+// to Funcs and true. A pointer t is looked up by its element type, because
+// wrappers are registered against the struct type itself. It returns nil,
+// false when t is nil or when no wrapper is registered for the key.
+func wrapMethodsKeyed(t reflect.Type, fn interface{}) (map[string]Func, bool) {
+       // reflect.TypeOf(nil) yields a nil Type; calling Kind on it would panic.
+       if t == nil {
+               return nil, false
+       }
+       // Registering happens on the struct type, not the pointer type.
+       if t.Kind() == reflect.Ptr {
+               t = t.Elem()
+       }
+       key := t.String()
+
+       structFuncsMu.Lock()
+       f, exists := structFuncs[key]
+       structFuncsMu.Unlock()
+       if !exists {
+               return nil, false
+       }
+       log.Debugf(context.Background(), "EXTRACTING StructWrapper for %v", key)
+       // Run the (possibly user-supplied) wrapper outside the lock so callers
+       // don't serialize on it and a wrapper that touches the registry
+       // cannot deadlock.
+       return f(fn), true
+}
diff --git a/sdks/go/pkg/beam/util/shimx/generate.go 
b/sdks/go/pkg/beam/util/shimx/generate.go
index 6c0eb4a231e4..31b70e54fa77 100644
--- a/sdks/go/pkg/beam/util/shimx/generate.go
+++ b/sdks/go/pkg/beam/util/shimx/generate.go
@@ -69,6 +69,7 @@ type Top struct {
        Imports   []string // the full import paths
        Functions []string // the plain names of the functions to be registered.
        Types     []string // the plain names of the types to be registered.
+       Wraps     []Wrap
        Emitters  []Emitter
        Inputs    []Input
        Shims     []Func
@@ -155,6 +156,13 @@ type Func struct {
        In, Out    []string
 }
 
+// Wrap describes the closure wrapper to generate for one structural DoFn
+// type, so that its methods can be invoked without reflective method calls.
+type Wrap struct {
+       Name    string // Identifier form of Type; names the generated wrapMaker function.
+       Type    string // The struct type as spelled in the generated code.
+       Methods []Func // Methods to wrap; each Func's Name is the method name.
+}
+
 // Name creates a capitalized identifier from a type string. The identifier
 // follows the rules of go identifiers and should be compileable.
 // See https://golang.org/ref/spec#Identifiers for details.
@@ -217,6 +225,9 @@ func init() {
 {{- range $x := .Types}}
        runtime.RegisterType(reflect.TypeOf((*{{$x}})(nil)).Elem())
 {{- end}}
+{{- range $x := .Wraps}}
+       
reflectx.RegisterStructWrapper(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(), 
wrapMaker{{$x.Name}})
+{{- end}}
 {{- range $x  := .Shims}}
        reflectx.RegisterFunc(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(), 
funcMaker{{$x.Name}})
 {{- end}}
@@ -228,7 +239,18 @@ func init() {
 {{- end}}
 }
 
-{{range $x  := .Shims -}}
+{{range $x := .Wraps -}}
+func wrapMaker{{$x.Name}}(fn interface{}) map[string]reflectx.Func {
+       dfn := fn.(*{{$x.Type}})
+       return map[string]reflectx.Func{
+       {{- range $y := .Methods}}
+               "{{$y.Name}}": reflectx.MakeFunc(func({{mkparams "a%d %v" 
$y.In}}) {{if $y.Out}}({{mkrets "%v" $y.Out}}) { return {{else -}} { {{end -}} 
dfn.{{$y.Name}}({{mktuplef (len $y.In) "a%d" }}) }),
+       {{- end}}
+       }
+}
+
+{{end}}
+{{- range $x  := .Shims -}}
 type caller{{$x.Name}} struct {
        fn {{$x.Type}}
 }
@@ -256,7 +278,7 @@ func (c *caller{{$x.Name}}) Call{{len $x.In}}x{{len 
$x.Out}}({{mkargs (len $x.In
 }
 
 {{end}}
-{{if .Emitters -}}
+{{- if .Emitters -}}
 type emitNative struct {
        n     exec.ElementProcessor
        fn    interface{}
@@ -277,8 +299,8 @@ func (e *emitNative) Value() interface{} {
        return e.fn
 }
 
-{{end -}}
-{{range $x := .Emitters -}}
+{{end}}
+{{- range $x := .Emitters -}}
 func emitMaker{{$x.Name}}(n exec.ElementProcessor) exec.ReusableEmitter {
        ret := &emitNative{n: n}
        ret.fn = ret.invoke{{.Name}}
@@ -331,8 +353,9 @@ func (v *iterNative) Reset() error {
        v.cur = nil
        return nil
 }
-{{- end}}
-{{- range $x := .Inputs}}
+
+{{end}}
+{{- range $x := .Inputs -}}
 func iterMaker{{$x.Name}}(s exec.ReStream) exec.ReusableInput {
        ret := &iterNative{s: s}
        ret.fn = ret.read{{$x.Name}}
@@ -363,8 +386,8 @@ func (v *iterNative) read{{$x.Name}}({{if $x.Time -}} et 
*typex.EventTime, {{end
 {{- end}}
        return true
 }
-{{- end}}
 
+{{end}}
 // DO NOT MODIFY: GENERATED CODE
 `))
 
@@ -372,6 +395,7 @@ func (v *iterNative) read{{$x.Name}}({{if $x.Time -}} et 
*typex.EventTime, {{end
+// funcMap is the template.FuncMap of formatting helpers referenced by the shim template (e.g. mkparams, mkrets, mktuplef).
 var funcMap template.FuncMap = map[string]interface{}{
        "mkargs":   mkargs,
        "mkparams": mkparams,
+       "mkrets":   mkrets, // renders the result-type list of generated wrapMaker closures
        "mktuple":  mktuple,
        "mktuplef": mktuplef,
 }
@@ -385,7 +409,7 @@ func mkargs(n int, format, typ string) string {
        return fmt.Sprintf("%v %v", mktuplef(n, format), typ)
 }
 
-// mkparams(format, []type) returns "<fmt.Sprintf(format, 0, type[0])>, .., 
<fmt.Sprintf(format, n-1, type[0])>".
+// mkparams(format, []type) returns "<fmt.Sprintf(format, 0, type[0])>, .., 
<fmt.Sprintf(format, n-1, type[n-1])>".
 func mkparams(format string, types []string) string {
        var ret []string
        for i, t := range types {
@@ -394,6 +418,15 @@ func mkparams(format string, types []string) string {
        return strings.Join(ret, ", ")
 }
 
+// mkrets(format, []type) formats each type as fmt.Sprintf(format, type) and
+// joins the results with ", ", e.g. "<fmt.Sprintf(format, type[0])>, ..,
+// <fmt.Sprintf(format, type[n-1])>". It returns "" when types is empty.
+func mkrets(format string, types []string) string {
+       formatted := make([]string, len(types))
+       for i := range types {
+               formatted[i] = fmt.Sprintf(format, types[i])
+       }
+       return strings.Join(formatted, ", ")
+}
+
 // mktuple(n, v) returns "v, v, ..., v".
 func mktuple(n int, v string) string {
        var ret []string
diff --git a/sdks/go/pkg/beam/util/shimx/generate_test.go 
b/sdks/go/pkg/beam/util/shimx/generate_test.go
index 3696bbab7f32..43058671e559 100644
--- a/sdks/go/pkg/beam/util/shimx/generate_test.go
+++ b/sdks/go/pkg/beam/util/shimx/generate_test.go
@@ -215,3 +215,20 @@ func TestMktuplef(t *testing.T) {
                }
        }
 }
+
+// TestMkrets checks that mkrets formats and comma-joins return types,
+// including the nil and empty-slice cases.
+func TestMkrets(t *testing.T) {
+       cases := []struct {
+               format string
+               types  []string
+               want   string
+       }{
+               {format: "%v", types: nil, want: ""},
+               {format: "%v", types: []string{}, want: ""},
+               {format: "%v", types: []string{"Foo", "baz", "*imp.Bar"}, want: "Foo, baz, *imp.Bar"},
+       }
+       for _, tc := range cases {
+               got := mkrets(tc.format, tc.types)
+               if got == tc.want {
+                       continue
+               }
+               t.Errorf("mkrets(%v,%v) = %v, want %v", tc.format, tc.types, got, tc.want)
+       }
+}
diff --git a/sdks/go/pkg/beam/util/starcgenx/starcgenx.go 
b/sdks/go/pkg/beam/util/starcgenx/starcgenx.go
index 003a3c91df71..23a95d6dd7e0 100644
--- a/sdks/go/pkg/beam/util/starcgenx/starcgenx.go
+++ b/sdks/go/pkg/beam/util/starcgenx/starcgenx.go
@@ -38,6 +38,7 @@ func NewExtractor(pkg string) *Extractor {
                Package:     pkg,
                functions:   make(map[string]struct{}),
                types:       make(map[string]struct{}),
+               wraps:       make(map[string]map[string]*types.Signature),
                funcs:       make(map[string]*types.Signature),
                emits:       make(map[string]shimx.Emitter),
                iters:       make(map[string]shimx.Input),
@@ -60,6 +61,8 @@ type Extractor struct {
        functions map[string]struct{}
        // Types to Register (structs, essentially)
        types map[string]struct{}
+       // StructuralDoFn wraps needed (receiver type, then method names)
+       wraps map[string]map[string]*types.Signature
        // FuncShims needed
        funcs map[string]*types.Signature
        // Emitter Shims needed
@@ -81,6 +84,7 @@ func (e *Extractor) Summary() {
        e.Printf("All exported?: %v\n", e.allExported)
        e.Printf("%d\t Functions\n", len(e.functions))
        e.Printf("%d\t Types\n", len(e.types))
+       e.Printf("%d\t Wraps\n", len(e.wraps))
        e.Printf("%d\t Shims\n", len(e.funcs))
        e.Printf("%d\t Emits\n", len(e.emits))
        e.Printf("%d\t Inputs\n", len(e.iters))
@@ -257,6 +261,20 @@ func (e *Extractor) fromObj(fset *token.FileSet, id 
*ast.Ident, obj types.Object
                                // If this is not a lifecycle method, we should 
ignore it.
                                return
                        }
+                       // This must be a structural DoFn! We should generate a 
closure wrapper for it.
+                       t := recv.Type()
+                       p, ok := t.(*types.Pointer)
+                       for ok {
+                               t = p.Elem()
+                               p, ok = t.(*types.Pointer)
+                       }
+                       ts := types.TypeString(t, e.qualifier)
+                       mthdMap := e.wraps[ts]
+                       if mthdMap == nil {
+                               mthdMap = make(map[string]*types.Signature)
+                               e.wraps[ts] = mthdMap
+                       }
+                       mthdMap[id.Name] = sig
                } else if id.Name != "init" {
                        // init functions are special and should be ignored.
                        // Functions need registering, as well as shim 
generation.
@@ -385,24 +403,20 @@ func (e *Extractor) Generate(filename string) []byte {
        for t := range e.types {
                typs = append(typs, t)
        }
+       var wraps []shimx.Wrap
+       for typ, mthdMap := range e.wraps {
+               wrap := shimx.Wrap{Type: typ, Name: shimx.Name(typ)}
+               for mName, mthd := range mthdMap {
+                       shim := e.makeFunc(mthd)
+                       shim.Name = mName
+                       wrap.Methods = append(wrap.Methods, shim)
+               }
+               wraps = append(wraps, wrap)
+       }
        var shims []shimx.Func
        for sig, t := range e.funcs {
-               shim := shimx.Func{Type: sig}
-               var inNames []string
-               in := t.Params() // *types.Tuple
-               for i := 0; i < in.Len(); i++ {
-                       s := in.At(i) // *types.Var
-                       shim.In = append(shim.In, types.TypeString(s.Type(), 
e.qualifier))
-                       inNames = append(inNames, e.NameType(s.Type()))
-               }
-               var outNames []string
-               out := t.Results() // *types.Tuple
-               for i := 0; i < out.Len(); i++ {
-                       s := out.At(i)
-                       shim.Out = append(shim.Out, types.TypeString(s.Type(), 
e.qualifier))
-                       outNames = append(outNames, e.NameType(s.Type()))
-               }
-               shim.Name = shimx.FuncName(inNames, outNames)
+               shim := e.makeFunc(t)
+               shim.Type = sig
                shims = append(shims, shim)
        }
        var emits []shimx.Emitter
@@ -429,6 +443,7 @@ func (e *Extractor) Generate(filename string) []byte {
                Imports:   imports,
                Functions: functions,
                Types:     typs,
+               Wraps:     wraps,
                Shims:     shims,
                Emitters:  emits,
                Inputs:    inputs,
@@ -438,6 +453,26 @@ func (e *Extractor) Generate(filename string) []byte {
        return e.w.Bytes()
 }
 
+func (e *Extractor) makeFunc(t *types.Signature) shimx.Func {
+       shim := shimx.Func{}
+       var inNames []string
+       in := t.Params() // *types.Tuple
+       for i := 0; i < in.Len(); i++ {
+               s := in.At(i) // *types.Var
+               shim.In = append(shim.In, types.TypeString(s.Type(), e.qualifier))
+               inNames = append(inNames, e.NameType(s.Type()))
+       }
+       var outNames []string
+       out := t.Results() // *types.Tuple
+       for i := 0; i < out.Len(); i++ {
+               s := out.At(i)
+               shim.Out = append(shim.Out, types.TypeString(s.Type(), e.qualifier))
+               outNames = append(outNames, e.NameType(s.Type()))
+       }
+       shim.Name = shimx.FuncName(inNames, outNames)
+       return shim
+}
+
 func (e *Extractor) makeEmitter(sig *types.Signature) (shimx.Emitter, bool) {
        // Emitters must have no return values.
        if sig.Results().Len() != 0 {
@@ -540,7 +575,7 @@ func (e *Extractor) varString(v *types.Var) string {
        return types.TypeString(v.Type(), e.qualifier)
 }
 
-// NameType turns a reflect.Type into a strying based on it's name.
+// NameType turns a reflect.Type into a string based on it's name.
 // It prefixes Emit or Iter if the function satisfies the constrains of those types.
 func (e *Extractor) NameType(t types.Type) string {
        switch a := t.(type) {
diff --git a/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go b/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go
index 9141acb114eb..4b437c1226f2 100644
--- a/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go
+++ b/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go
@@ -43,8 +43,8 @@ func TestExtractor(t *testing.T) {
                        expected: []string{"runtime.RegisterFunction(iterFn)", 
"funcMakerStringIterIntГ", "iterMakerInt"},
                },
                {name: "structs1", files: []string{structs}, pkg: "structs", 
ids: []string{"myDoFn"},
-                       expected: 
[]string{"runtime.RegisterType(reflect.TypeOf((*myDoFn)(nil)).Elem())", 
"funcMakerEmitIntГ", "emitMakerInt", "funcMakerValTypeValTypeEmitIntГ", 
"runtime.RegisterType(reflect.TypeOf((*valType)(nil)).Elem())"},
-                       excluded: []string{"funcMakerStringГ", 
"emitMakerString", "nonPipelineType"},
+                       expected: 
[]string{"runtime.RegisterType(reflect.TypeOf((*myDoFn)(nil)).Elem())", 
"funcMakerEmitIntГ", "emitMakerInt", "funcMakerValTypeValTypeEmitIntГ", 
"runtime.RegisterType(reflect.TypeOf((*valType)(nil)).Elem())", 
"reflectx.RegisterStructWrapper(reflect.TypeOf((*myDoFn)(nil)).Elem(), 
wrapMakerMyDoFn)"},
+                       excluded: []string{"funcMakerStringГ", 
"emitMakerString", "nonPipelineType", "UnrelatedMethod1", "UnrelatedMethod2", 
"UnrelatedMethod3"},
                },
        }
        for _, test := range tests {


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 170763)
    Time Spent: 7h 20m  (was: 7h 10m)

> Make it easy to generate type-specialized Go SDK reflectx.Funcs
> ---------------------------------------------------------------
>
>                 Key: BEAM-3612
>                 URL: https://issues.apache.org/jira/browse/BEAM-3612
>             Project: Beam
>          Issue Type: Improvement
>          Components: sdk-go
>            Reporter: Henning Rohde
>            Assignee: Robert Burke
>            Priority: Major
>          Time Spent: 7h 20m
>  Remaining Estimate: 0h
>




--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to