[
https://issues.apache.org/jira/browse/BEAM-3612?focusedWorklogId=165044&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-165044
]
ASF GitHub Bot logged work on BEAM-3612:
----------------------------------------
Author: ASF GitHub Bot
Created on: 12/Nov/18 18:03
Start Date: 12/Nov/18 18:03
Worklog Time Spent: 10m
Work Description: aaltay closed pull request #7000: [BEAM-3612] Add a
shim generator tool
URL: https://github.com/apache/beam/pull/7000
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/sdks/go/cmd/starcgen/starcgen.go b/sdks/go/cmd/starcgen/starcgen.go
new file mode 100644
index 00000000000..87e80110b39
--- /dev/null
+++ b/sdks/go/cmd/starcgen/starcgen.go
@@ -0,0 +1,154 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// starcgen is a tool to generate specialized type assertion shims to be
+// used in Apache Beam Go SDK pipelines instead of the default reflection shim.
+// This is done through static analysis of go sources for the package in
question.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/importer"
+ "go/parser"
+ "go/token"
+ "io"
+ "log"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/apache/beam/sdks/go/pkg/beam/util/starcgenx"
+)
+
+var (
+ inputs = flag.String("inputs", "", "comma separated list of file with
types to create")
+ output = flag.String("output", "", "output file with types to create")
+ ids = flag.String("identifiers", "", "comma separated list of
package local identifiers for which to generate code")
+)
+
+// Generate takes the typechecked inputs, and generates the shim file for the
relevant
+// identifiers.
+func Generate(w io.Writer, filename, pkg string, ids []string, fset
*token.FileSet, files []*ast.File) error {
+ e := starcgenx.NewExtractor(pkg)
+ e.Ids = ids
+
+ // Importing from source should work in most cases.
+ imp := importer.For("source", nil)
+ if err := e.FromAsts(imp, fset, files); err != nil {
+ // Always print out the debugging info to the file.
+ if _, errw := w.Write(e.Bytes()); errw != nil {
+ return fmt.Errorf("error writing debug data to file
after err %v:%v", err, errw)
+ }
+ return fmt.Errorf("error extracting from asts: %v", err)
+ }
+
+ e.Print("*/\n")
+ data := e.Generate(filename)
+ if err := write(w, []byte(license)); err != nil {
+ return err
+ }
+ return write(w, data)
+}
+
+// write writes all of data to w.
+//
+// It returns nil only when w accepted every byte. If w reports an error, that
+// error is preserved: it is returned as-is for a full-length write, or folded
+// into the short write message when fewer than len(data) bytes were written.
+// A writer that accepts fewer bytes without reporting an error yields
+// io.ErrShortWrite rather than silent success.
+func write(w io.Writer, data []byte) error {
+	n, err := w.Write(data)
+	switch {
+	case err != nil && n < len(data):
+		// Keep the underlying cause visible instead of masking it.
+		return fmt.Errorf("short write of data got %d, want %d: %v", n, len(data), err)
+	case err != nil:
+		return err
+	case n < len(data):
+		return io.ErrShortWrite
+	}
+	return nil
+}
+
+func usage() {
+ fmt.Fprintf(os.Stderr, "Usage: %v [options] --inputs=<comma separated
of go files>\n", filepath.Base(os.Args[0]))
+ flag.PrintDefaults()
+}
+
+// main parses every file named by --inputs, checks that they all declare the
+// same package, and writes the generated shim file for the identifiers given
+// by --identifiers to --output.
+//
+// When --output is unset, the shim is written next to the first input as
+// "<name>.shims.go", where <name> is the package name, or for a single input
+// the input's base name up to its first ".".
+func main() {
+	flag.Usage = usage
+	flag.Parse()
+
+	log.SetFlags(log.Lshortfile)
+	log.SetPrefix("starcgen: ")
+
+	ipts := strings.Split(*inputs, ",")
+	fset := token.NewFileSet()
+	var fs []*ast.File
+	var pkg string
+
+	dir, err := filepath.Abs(filepath.Dir(os.Args[0]))
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	for _, i := range ipts {
+		f, err := parser.ParseFile(fset, i, nil, 0)
+		if err != nil {
+			// Retry relative to the executable's directory before giving up,
+			// and report both failures if that does not work either.
+			err1 := err
+			f, err = parser.ParseFile(fset, filepath.Join(dir, i), nil, 0)
+			if err != nil {
+				log.Print(err1)
+				log.Fatal(err) // parse error
+			}
+		}
+
+		if pkg == "" {
+			pkg = f.Name.Name
+		} else if pkg != f.Name.Name {
+			log.Fatalf("Input file %v has mismatched package path, got %q, want %q", i, f.Name.Name, pkg)
+		}
+		fs = append(fs, f)
+	}
+	if pkg == "" {
+		// Dereference the flag: printing the *string would show an address.
+		log.Fatalf("No package detected in input files: %v", *inputs)
+	}
+
+	if *output == "" {
+		name := pkg
+		if len(ipts) == 1 {
+			name = filepath.Base(ipts[0])
+			if index := strings.Index(name, "."); index > 0 {
+				name = name[:index]
+			}
+		}
+		*output = filepath.Join(filepath.Dir(ipts[0]), name+".shims.go")
+	}
+
+	f, err := os.OpenFile(*output, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
+	if err != nil {
+		log.Fatal(err)
+	}
+	if err := Generate(f, *output, pkg, strings.Split(*ids, ","), fset, fs); err != nil {
+		// log.Fatal exits without running deferred calls, so close explicitly.
+		f.Close()
+		log.Fatal(err)
+	}
+	// A failed close can mean the generated file was not fully written.
+	if err := f.Close(); err != nil {
+		log.Fatal(err)
+	}
+}
+
+const license = `// Licensed to the Apache Software Foundation (ASF) under one
or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+`
diff --git a/sdks/go/cmd/starcgen/starcgen_test.go
b/sdks/go/cmd/starcgen/starcgen_test.go
new file mode 100644
index 00000000000..7282ada8a27
--- /dev/null
+++ b/sdks/go/cmd/starcgen/starcgen_test.go
@@ -0,0 +1,123 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "strings"
+ "testing"
+)
+
+func TestGenerate(t *testing.T) {
+ tests := []struct {
+ name string
+ pkg string
+ files []string
+ ids []string
+ expected []string
+ excluded []string
+ }{
+ {name: "genAllSingleFile", files: []string{hello1}, pkg:
"hello", ids: []string{},
+ expected: []string{"runtime.RegisterFunction(MyTitle)",
"runtime.RegisterFunction(MyOtherDoFn)",
"runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())",
"funcMakerContext۰ContextStringГString", "funcMakerFooГString"},
+ },
+ {name: "genSpecificSingleFile", files: []string{hello1}, pkg:
"hello", ids: []string{"MyTitle"},
+ expected: []string{"runtime.RegisterFunction(MyTitle)",
"funcMakerContext۰ContextStringГString"},
+ excluded: []string{"MyOtherDoFn",
"runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())",
"funcMakerFooГString"},
+ },
+ {name: "genAllMultiFile", files: []string{hello1, hello2}, pkg:
"hello", ids: []string{},
+ expected: []string{"runtime.RegisterFunction(MyTitle)",
"runtime.RegisterFunction(MyOtherDoFn)", "runtime.RegisterFunction(anotherFn)",
"runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())",
"funcMakerContext۰ContextStringГString", "funcMakerFooГString",
"funcMakerShimx۰EmitterГString", "funcMakerShimx۰EmitterГFoo"},
+ },
+ {name: "genSpecificMultiFile1", files: []string{hello1,
hello2}, pkg: "hello", ids: []string{"MyTitle"},
+ expected: []string{"runtime.RegisterFunction(MyTitle)",
"funcMakerContext۰ContextStringГString"},
+ excluded: []string{"MyOtherDoFn", "anotherFn",
"runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())",
"funcMakerFooГString", "funcMakerShimx۰EmitterГString",
"funcMakerShimx۰EmitterГFoo"},
+ },
+ {name: "genSpecificMultiFile2", files: []string{hello1,
hello2}, pkg: "hello", ids: []string{"anotherFn"},
+ expected: []string{"funcMakerShimx۰EmitterГString",
"funcMakerShimx۰EmitterГString"},
+ excluded: []string{"MyOtherDoFn", "MyTitle",
"runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())",
"funcMakerFooГString"},
+ },
+ }
+ for _, test := range tests {
+ test := test
+ t.Run(test.name, func(t *testing.T) {
+ fset := token.NewFileSet()
+ var fs []*ast.File
+ for i, f := range test.files {
+ n, err := parser.ParseFile(fset, "", f, 0)
+ if err != nil {
+ t.Fatalf("couldn't parse
test.files[%d]: %v", i, err)
+ }
+ fs = append(fs, n)
+ }
+ var b bytes.Buffer
+ if err := Generate(&b, test.name+".go", test.pkg,
test.ids, fset, fs); err != nil {
+ t.Fatal(err)
+ }
+ s := string(b.Bytes())
+ for _, i := range test.expected {
+ if !strings.Contains(s, i) {
+ t.Errorf("expected %q in generated
file", i)
+ }
+ }
+ for _, i := range test.excluded {
+ if strings.Contains(s, i) {
+ t.Errorf("found %q in generated file",
i)
+ }
+ }
+ t.Log(s)
+ })
+ }
+}
+
+const hello1 = `
+package hello
+
+import (
+ "context"
+ "strings"
+)
+
+func MyTitle(ctx context.Context, v string) string {
+ return strings.Title(v)
+}
+
+type foo struct{}
+
+func MyOtherDoFn(v foo) string {
+ return "constant"
+}
+`
+
+const hello2 = `
+package hello
+
+import (
+ "context"
+ "strings"
+
+ "github.com/apache/beam/sdks/go/pkg/beam/util/shimx"
+)
+
+func anotherFn(v shimx.Emitter) string {
+ return v.Name
+}
+
+func fooFn(v shimx.Emitter) foo {
+ return foo{}
+}
+`
diff --git a/sdks/go/pkg/beam/util/shimx/generate.go
b/sdks/go/pkg/beam/util/shimx/generate.go
new file mode 100644
index 00000000000..6c0eb4a231e
--- /dev/null
+++ b/sdks/go/pkg/beam/util/shimx/generate.go
@@ -0,0 +1,413 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package shimx specifies the templates for generating type assertion shims
for
+// Apache Beam Go SDK pipelines.
+//
+// In particular, the shims are used by the Beam Go SDK to avoid reflection at
runtime,
+// which is the default mode of operation. The shims are specialized for the
code
+// in question, using type assertion to convert arguments as required, and
invoke the
+// user code.
+//
+// Similar shims are required for emitters, and iterators in order to
propagate values
+// out of, and in to user functions respectively without reflection overhead.
+//
+// Registering user types is required to support user types as PCollection
+// types, while registering functions is required to avoid possibly expensive
function
+// resolution at worker start up, which defaults to using DWARF Symbol tables.
+//
+// The generator largely relies on basic types and strings to ensure that it's
usable
+// by dynamic processes via reflection, or by any static analysis approach
that is
+// used in the future.
+package shimx
+
+import (
+ "fmt"
+ "io"
+ "sort"
+ "strings"
+ "text/template"
+)
+
+// Beam imports that the generated code requires.
+var (
+ ExecImport =
"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/exec"
+ TypexImport = "github.com/apache/beam/sdks/go/pkg/beam/core/typex"
+ ReflectxImport =
"github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx"
+ RuntimeImport = "github.com/apache/beam/sdks/go/pkg/beam/core/runtime"
+)
+
+func validateBeamImports() {
+ checkImportSuffix(ExecImport, "exec")
+ checkImportSuffix(TypexImport, "typex")
+ checkImportSuffix(ReflectxImport, "reflectx")
+ checkImportSuffix(RuntimeImport, "runtime")
+}
+
+func checkImportSuffix(path, suffix string) {
+ if !strings.HasSuffix(path, suffix) {
+ panic(fmt.Sprintf("expected %v to end with %v. can't generate
valid code", path, suffix))
+ }
+}
+
+// Top is the top level inputs into the template file for generating shims.
+type Top struct {
+ FileName, ToolName, Package string
+
+ Imports []string // the full import paths
+ Functions []string // the plain names of the functions to be registered.
+ Types []string // the plain names of the types to be registered.
+ Emitters []Emitter
+ Inputs []Input
+ Shims []Func
+}
+
+// sort orders the shims consistently to minimize diffs in the generated code.
+func (t *Top) sort() {
+ sort.Strings(t.Imports)
+ sort.Strings(t.Functions)
+ sort.Strings(t.Types)
+ sort.SliceStable(t.Emitters, func(i, j int) bool {
+ return t.Emitters[i].Name < t.Emitters[j].Name
+ })
+ sort.SliceStable(t.Inputs, func(i, j int) bool {
+ return t.Inputs[i].Name < t.Inputs[j].Name
+ })
+ sort.SliceStable(t.Shims, func(i, j int) bool {
+ return t.Shims[i].Name < t.Shims[j].Name
+ })
+}
+
+// processImports removes imports that are otherwise handled by the template
+// This method is on the value to shallow copy the Field references to avoid
+// mutating the user provided instance.
+func (t Top) processImports() *Top {
+ pred := map[string]bool{"reflect": true}
+ var filtered []string
+ if len(t.Emitters) > 0 {
+ pred["context"] = true
+ }
+ if len(t.Inputs) > 0 {
+ pred["fmt"] = true
+ pred["io"] = true
+ }
+ if len(t.Types) > 0 || len(t.Functions) > 0 {
+ filtered = append(filtered, RuntimeImport)
+ pred[RuntimeImport] = true
+ }
+ if len(t.Shims) > 0 {
+ filtered = append(filtered, ReflectxImport)
+ pred[ReflectxImport] = true
+ }
+ if len(t.Emitters) > 0 || len(t.Inputs) > 0 {
+ filtered = append(filtered, ExecImport)
+ pred[ExecImport] = true
+ }
+ needTypexImport := len(t.Emitters) > 0
+ for _, i := range t.Inputs {
+ if i.Time {
+ needTypexImport = true
+ break
+ }
+ }
+ if needTypexImport {
+ filtered = append(filtered, TypexImport)
+ pred[TypexImport] = true
+ }
+ for _, imp := range t.Imports {
+ if !pred[imp] {
+ filtered = append(filtered, imp)
+ }
+ }
+ t.Imports = filtered
+ return &t
+}
+
+// Emitter represents an emitter shim to be generated.
+type Emitter struct {
+ Name, Type string // The user name of the function, the type of the
emit.
+ Time bool // if this uses event time.
+ Key, Val string // The type of the emits.
+}
+
+// Input represents an iterator shim to be generated.
+type Input struct {
+ Name, Type string // The user name of the function, the type of the
iterator (including the bool).
+ Time bool // if this uses event time.
+ Key, Val string // The type of the inputs, pointers removed.
+}
+
+// Func represents a type assertion shim for function invocation to be
generated.
+type Func struct {
+ Name, Type string
+ In, Out []string
+}
+
+// Name creates a capitalized identifier from a type string. The identifier
+// follows the rules of go identifiers and should be compileable.
+// See https://golang.org/ref/spec#Identifiers for details.
+func Name(t string) string {
+ if strings.HasPrefix(t, "[]") {
+ return Name(t[2:] + "Slice")
+ }
+
+ t = strings.Replace(t, "beam.", "typex.", -1)
+ t = strings.Replace(t, ".", "۰", -1) // For packages
+ t = strings.Replace(t, "*", "Ꮨ", -1) // For pointers
+ t = strings.Replace(t, "[", "_", -1) // For maps
+ t = strings.Replace(t, "]", "_", -1)
+ return strings.Title(t)
+}
+
+// FuncName returns a compilable Go identifier for a function, given valid
+// type names as generated by Name.
+// See https://golang.org/ref/spec#Identifiers for details.
+func FuncName(inNames, outNames []string) string {
+ return fmt.Sprintf("%sГ%s", strings.Join(inNames, ""),
strings.Join(outNames, ""))
+}
+
+// File writes go code to the given writer.
+func File(w io.Writer, top *Top) {
+ validateBeamImports()
+ top = top.processImports()
+ top.sort()
+ vampireTemplate.Funcs(funcMap).Execute(w, top)
+}
+
+var vampireTemplate =
template.Must(template.New("vampire").Funcs(funcMap).Parse(`// Code generated
by {{.ToolName}}. DO NOT EDIT.
+// File: {{.FileName}}
+
+package {{.Package}}
+
+import (
+
+{{- if .Emitters}}
+ "context"
+{{- end}}
+{{- if .Inputs}}
+ "fmt"
+ "io"
+{{- end}}
+ "reflect"
+{{- if .Imports}}
+
+ // Library imports
+{{- end}}
+{{- range $import := .Imports}}
+ "{{$import}}"
+{{- end}}
+)
+
+func init() {
+{{- range $x := .Functions}}
+ runtime.RegisterFunction({{$x}})
+{{- end}}
+{{- range $x := .Types}}
+ runtime.RegisterType(reflect.TypeOf((*{{$x}})(nil)).Elem())
+{{- end}}
+{{- range $x := .Shims}}
+ reflectx.RegisterFunc(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(),
funcMaker{{$x.Name}})
+{{- end}}
+{{- range $x := .Emitters}}
+ exec.RegisterEmitter(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(),
emitMaker{{$x.Name}})
+{{- end}}
+{{- range $x := .Inputs}}
+ exec.RegisterInput(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(),
iterMaker{{$x.Name}})
+{{- end}}
+}
+
+{{range $x := .Shims -}}
+type caller{{$x.Name}} struct {
+ fn {{$x.Type}}
+}
+
+func funcMaker{{$x.Name}}(fn interface{}) reflectx.Func {
+ f := fn.({{$x.Type}})
+ return &caller{{$x.Name}}{fn: f}
+}
+
+func (c *caller{{$x.Name}}) Name() string {
+ return reflectx.FunctionName(c.fn)
+}
+
+func (c *caller{{$x.Name}}) Type() reflect.Type {
+ return reflect.TypeOf(c.fn)
+}
+
+func (c *caller{{$x.Name}}) Call(args []interface{}) []interface{} {
+ {{mktuplef (len $x.Out) "out%d"}}{{- if len $x.Out}} := {{end
-}}c.fn({{mkparams "args[%d].(%v)" $x.In}})
+ return []interface{}{ {{- mktuplef (len $x.Out) "out%d" -}} }
+}
+
+func (c *caller{{$x.Name}}) Call{{len $x.In}}x{{len $x.Out}}({{mkargs (len
$x.In) "arg%v" "interface{}"}}) ({{- mktuple (len $x.Out) "interface{}"}}) {
+ {{if len $x.Out}}return {{end}}c.fn({{mkparams "arg%d.(%v)" $x.In}})
+}
+
+{{end}}
+{{if .Emitters -}}
+type emitNative struct {
+ n exec.ElementProcessor
+ fn interface{}
+
+ ctx context.Context
+ ws []typex.Window
+ et typex.EventTime
+}
+
+func (e *emitNative) Init(ctx context.Context, ws []typex.Window, et
typex.EventTime) error {
+ e.ctx = ctx
+ e.ws = ws
+ e.et = et
+ return nil
+}
+
+func (e *emitNative) Value() interface{} {
+ return e.fn
+}
+
+{{end -}}
+{{range $x := .Emitters -}}
+func emitMaker{{$x.Name}}(n exec.ElementProcessor) exec.ReusableEmitter {
+ ret := &emitNative{n: n}
+ ret.fn = ret.invoke{{.Name}}
+ return ret
+}
+
+func (e *emitNative) invoke{{$x.Name}}({{if $x.Time -}} t typex.EventTime,
{{end}}{{if $x.Key}}key {{$x.Key}}, {{end}}val {{$x.Val}}) {
+ value := exec.FullValue{Windows: e.ws, Timestamp: {{- if $x.Time}}
t{{else}} e.et{{end}}, {{- if $x.Key}} Elm: key, Elm2: val {{else}} Elm:
val{{end -}} }
+ if err := e.n.ProcessElement(e.ctx, value); err != nil {
+ panic(err)
+ }
+}
+
+{{end}}
+{{- if .Inputs -}}
+type iterNative struct {
+ s exec.ReStream
+ fn interface{}
+
+ // cur is the "current" stream, if any.
+ cur exec.Stream
+}
+
+func (v *iterNative) Init() error {
+ cur, err := v.s.Open()
+ if err != nil {
+ return err
+ }
+ v.cur = cur
+ return nil
+}
+
+func (v *iterNative) Value() interface{} {
+ return v.fn
+}
+
+func convToString(v interface{}) string {
+ switch v.(type) {
+ case []byte:
+ return string(v.([]byte))
+ default:
+ return v.(string)
+ }
+}
+
+func (v *iterNative) Reset() error {
+ if err := v.cur.Close(); err != nil {
+ return err
+ }
+ v.cur = nil
+ return nil
+}
+{{- end}}
+{{- range $x := .Inputs}}
+func iterMaker{{$x.Name}}(s exec.ReStream) exec.ReusableInput {
+ ret := &iterNative{s: s}
+ ret.fn = ret.read{{$x.Name}}
+ return ret
+}
+
+func (v *iterNative) read{{$x.Name}}({{if $x.Time -}} et *typex.EventTime,
{{end}}{{if $x.Key}}key *{{$x.Key}}, {{end}}value *{{$x.Val}}) bool {
+ elm, err := v.cur.Read()
+ if err != nil {
+ if err == io.EOF {
+ return false
+ }
+ panic(fmt.Sprintf("broken stream: %v", err))
+ }
+
+{{- if $x.Time}}
+ *et = elm.Timestamp
+{{- end}}
+{{- if eq $x.Key "string"}}
+ *key = convToString(elm.Elm)
+{{- else if $x.Key}}
+ *key = elm.Elm.({{$x.Key}})
+{{- end}}
+{{- if eq $x.Val "string"}}
+ *value = convToString(elm.Elm{{- if $x.Key -}} 2 {{- end -}})
+{{- else}}
+ *value = elm.Elm{{- if $x.Key -}} 2 {{- end -}}.({{$x.Val}})
+{{- end}}
+ return true
+}
+{{- end}}
+
+// DO NOT MODIFY: GENERATED CODE
+`))
+
+// funcMap contains useful functions for use in the template.
+var funcMap template.FuncMap = map[string]interface{}{
+ "mkargs": mkargs,
+ "mkparams": mkparams,
+ "mktuple": mktuple,
+ "mktuplef": mktuplef,
+}
+
+// mkargs(n, format, type) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)> type".
+// If n is 0, it returns the empty string.
+func mkargs(n int, format, typ string) string {
+ if n == 0 {
+ return ""
+ }
+ return fmt.Sprintf("%v %v", mktuplef(n, format), typ)
+}
+
+// mkparams(format, []type) returns "<fmt.Sprintf(format, 0, type[0])>, .., <fmt.Sprintf(format, n-1, type[n-1])>".
+func mkparams(format string, types []string) string {
+ var ret []string
+ for i, t := range types {
+ ret = append(ret, fmt.Sprintf(format, i, t))
+ }
+ return strings.Join(ret, ", ")
+}
+
+// mktuple(n, v) returns "v, v, ..., v".
+func mktuple(n int, v string) string {
+ var ret []string
+ for i := 0; i < n; i++ {
+ ret = append(ret, v)
+ }
+ return strings.Join(ret, ", ")
+}
+
+// mktuplef(n, format) returns "<fmt.Sprintf(format, 0)>, ..,
<fmt.Sprintf(format, n-1)>"
+func mktuplef(n int, format string) string {
+ var ret []string
+ for i := 0; i < n; i++ {
+ ret = append(ret, fmt.Sprintf(format, i))
+ }
+ return strings.Join(ret, ", ")
+}
diff --git a/sdks/go/pkg/beam/util/shimx/generate_test.go
b/sdks/go/pkg/beam/util/shimx/generate_test.go
new file mode 100644
index 00000000000..3696bbab7f3
--- /dev/null
+++ b/sdks/go/pkg/beam/util/shimx/generate_test.go
@@ -0,0 +1,217 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shimx
+
+import (
+ "bytes"
+ "sort"
+ "testing"
+)
+
+// TestTop_Sort checks that Top.sort orders every slice field: the string
+// slices lexically, and the Emitters, Inputs and Shims by their Name.
+func TestTop_Sort(t *testing.T) {
+	top := Top{
+		Imports:   []string{"z", "a", "r"},
+		Functions: []string{"z", "a", "r"},
+		Types:     []string{"z", "a", "r"},
+		Emitters:  []Emitter{{Name: "z"}, {Name: "a"}, {Name: "r"}},
+		Inputs:    []Input{{Name: "z"}, {Name: "a"}, {Name: "r"}},
+		Shims:     []Func{{Name: "z"}, {Name: "a"}, {Name: "r"}},
+	}
+
+	top.sort()
+	if !sort.SliceIsSorted(top.Imports, func(i, j int) bool { return top.Imports[i] < top.Imports[j] }) {
+		t.Errorf("top.Imports not sorted: got %v, want it sorted", top.Imports)
+	}
+	if !sort.SliceIsSorted(top.Functions, func(i, j int) bool { return top.Functions[i] < top.Functions[j] }) {
+		// Name the field actually being checked so a failure points at it.
+		t.Errorf("top.Functions not sorted: got %v, want it sorted", top.Functions)
+	}
+	if !sort.SliceIsSorted(top.Types, func(i, j int) bool { return top.Types[i] < top.Types[j] }) {
+		t.Errorf("top.Types not sorted: got %v, want it sorted", top.Types)
+	}
+	if !sort.SliceIsSorted(top.Emitters, func(i, j int) bool { return top.Emitters[i].Name < top.Emitters[j].Name }) {
+		t.Errorf("top.Emitters not sorted by name: got %v, want it sorted", top.Emitters)
+	}
+	if !sort.SliceIsSorted(top.Inputs, func(i, j int) bool { return top.Inputs[i].Name < top.Inputs[j].Name }) {
+		t.Errorf("top.Inputs not sorted by name: got %v, want it sorted", top.Inputs)
+	}
+	if !sort.SliceIsSorted(top.Shims, func(i, j int) bool { return top.Shims[i].Name < top.Shims[j].Name }) {
+		t.Errorf("top.Shims not sorted: got %v, want it sorted", top.Shims)
+	}
+}
+
+// TestTop_ProcessImports checks which user imports survive processImports,
+// and which Beam imports it adds, for each kind of generated content.
+//
+// Every case starts from the same user import list; the expected result
+// drops the imports the template emits itself for that case and prepends
+// the Beam imports the generated code needs.
+func TestTop_ProcessImports(t *testing.T) {
+	needsFiltering := []string{"context", "keepit", "fmt", "io", "reflect", "unrelated"}
+
+	tests := []struct {
+		name string
+		got  *Top
+		want []string
+	}{
+		{name: "default", got: &Top{}, want: []string{"context", "keepit", "fmt", "io", "unrelated"}},
+		{name: "emit", got: &Top{Emitters: []Emitter{{Name: "emit"}}}, want: []string{ExecImport, TypexImport, "keepit", "fmt", "io", "unrelated"}},
+		{name: "iter", got: &Top{Inputs: []Input{{Name: "iter"}}}, want: []string{ExecImport, "context", "keepit", "unrelated"}},
+		{name: "iterWTime", got: &Top{Inputs: []Input{{Name: "iterWTime", Time: true}}}, want: []string{ExecImport, TypexImport, "context", "keepit", "unrelated"}},
+		{name: "shim", got: &Top{Shims: []Func{{Name: "emit"}}}, want: []string{ReflectxImport, "context", "keepit", "fmt", "io", "unrelated"}},
+		{name: "iter&emit", got: &Top{Emitters: []Emitter{{Name: "emit"}}, Inputs: []Input{{Name: "iter"}}}, want: []string{ExecImport, TypexImport, "keepit", "unrelated"}},
+		{name: "functions", got: &Top{Functions: []string{"func1"}}, want: []string{RuntimeImport, "context", "keepit", "fmt", "io", "unrelated"}},
+		{name: "types", got: &Top{Types: []string{"func1"}}, want: []string{RuntimeImport, "context", "keepit", "fmt", "io", "unrelated"}},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			top := test.got
+			top.Imports = needsFiltering
+			top = top.processImports()
+			// Compare lengths first: otherwise a result missing trailing
+			// imports would pass, and one with extra imports would panic.
+			if len(top.Imports) != len(test.want) {
+				t.Fatalf("want %v, got %v", test.want, top.Imports)
+			}
+			for i := range top.Imports {
+				if top.Imports[i] != test.want[i] {
+					t.Fatalf("want %v, got %v", test.want, top.Imports)
+				}
+			}
+		})
+	}
+}
+
+func TestName(t *testing.T) {
+ tests := []struct {
+ have, want string
+ }{
+ {"int", "Int"},
+ {"foo.MyInt", "Foo۰MyInt"},
+ {"[]beam.X", "Typex۰XSlice"},
+ {"map[int]beam.X", "Map_int_typex۰X"},
+ {"map[string]*beam.X", "Map_string_Ꮨtypex۰X"},
+ }
+ for _, test := range tests {
+ if got := Name(test.have); got != test.want {
+ t.Errorf("Name(%v) = %v, want %v", test.have, got,
test.want)
+ }
+ }
+}
+
+func TestFuncName(t *testing.T) {
+ tests := []struct {
+ in, out []string
+ want string
+ }{
+ {in: []string{"Int"}, out: []string{"Int"}, want: "IntГInt"},
+ {in: []string{"Int"}, out: []string{}, want: "IntГ"},
+ {in: []string{}, out: []string{"Bool"}, want: "ГBool"},
+ {in: []string{"Bool", "String"}, out: []string{"Int", "Bool"},
want: "BoolStringГIntBool"},
+ {in: []string{"String", "Map_int_typex۰X"}, out:
[]string{"Int", "Typex۰XSlice"}, want: "StringMap_int_typex۰XГIntTypex۰XSlice"},
+ }
+ for _, test := range tests {
+ if got := FuncName(test.in, test.out); got != test.want {
+ t.Errorf("FuncName(%v,%v) = %v, want %v", test.in,
test.out, got, test.want)
+ }
+ }
+}
+
+func TestFile(t *testing.T) {
+ top := Top{
+ Package: "gentest",
+ Imports: []string{"z", "a", "r"},
+ Functions: []string{"z", "a", "r"},
+ Types: []string{"z", "a", "r"},
+ Emitters: []Emitter{
+ {Name: "z", Type: "func(int)", Val: "Int"},
+ {Name: "a", Type: "func(bool, int) bool", Key: "bool",
Val: "int"},
+ {Name: "r", Type: "func(typex.EventTime, string) bool",
Time: true, Val: "string"},
+ },
+ Inputs: []Input{
+ {Name: "z", Type: "func(*int) bool"},
+ {Name: "a", Type: "func(*bool, *int) bool", Key:
"bool", Val: "int"},
+ {Name: "r", Type: "func(*typex.EventTime, *string)
bool", Time: true, Val: "string"},
+ },
+ Shims: []Func{
+ {Name: "z", Type: "func(string, func(int))", In:
[]string{"string", "func(int)"}},
+ {Name: "a", Type: "func(float64) (int, int)", In:
[]string{"float64"}, Out: []string{"int", "int"}},
+ {Name: "r", Type: "func(string, func(int))", In:
[]string{"string", "func(int)"}},
+ },
+ }
+ top.sort()
+
+ var b bytes.Buffer
+ if err := vampireTemplate.Funcs(funcMap).Execute(&b, top); err != nil {
+ t.Errorf("error generating template: %v", err)
+ }
+}
+
+func TestMkargs(t *testing.T) {
+ tests := []struct {
+ n int
+ format, typ string
+ want string
+ }{
+ {n: 0, format: "Foo", typ: "Bar", want: ""},
+ {n: 1, format: "arg%d", typ: "Bar", want: "arg0 Bar"},
+ {n: 4, format: "a%d", typ: "Baz", want: "a0, a1, a2, a3 Baz"},
+ }
+ for _, test := range tests {
+ if got := mkargs(test.n, test.format, test.typ); got !=
test.want {
+ t.Errorf("mkargs(%v,%v,%v) = %v, want %v", test.n,
test.format, test.typ, got, test.want)
+ }
+ }
+}
+
+func TestMkparams(t *testing.T) {
+ tests := []struct {
+ format string
+ types []string
+ want string
+ }{
+ {format: "Foo", types: []string{}, want: ""},
+ {format: "arg%d %v", types: []string{"Bar"}, want: "arg0 Bar"},
+ {format: "a%d %v", types: []string{"Foo", "Baz",
"interface{}"}, want: "a0 Foo, a1 Baz, a2 interface{}"},
+ }
+ for _, test := range tests {
+ if got := mkparams(test.format, test.types); got != test.want {
+ t.Errorf("mkparams(%v,%v) = %v, want %v", test.format,
test.types, got, test.want)
+ }
+ }
+}
+
+func TestMktuple(t *testing.T) {
+ tests := []struct {
+ n int
+ v string
+ want string
+ }{
+ {n: 0, v: "Foo", want: ""},
+ {n: 1, v: "Bar", want: "Bar"},
+ {n: 4, v: "Baz", want: "Baz, Baz, Baz, Baz"},
+ }
+ for _, test := range tests {
+ if got := mktuple(test.n, test.v); got != test.want {
+ t.Errorf("mktuple(%v,%v) = %v, want %v", test.n,
test.v, got, test.want)
+ }
+ }
+}
+
+func TestMktuplef(t *testing.T) {
+ tests := []struct {
+ n int
+ format, typ string
+ want string
+ }{
+ {n: 0, format: "Foo%d", want: ""},
+ {n: 1, format: "arg%d", want: "arg0"},
+ {n: 4, format: "a%d", want: "a0, a1, a2, a3"},
+ }
+ for _, test := range tests {
+ if got := mktuplef(test.n, test.format); got != test.want {
+ t.Errorf("mktuplef(%v,%v) = %v, want %v", test.n,
test.format, got, test.want)
+ }
+ }
+}
diff --git a/sdks/go/pkg/beam/util/starcgenx/starcgenx.go
b/sdks/go/pkg/beam/util/starcgenx/starcgenx.go
new file mode 100644
index 00000000000..003a3c91df7
--- /dev/null
+++ b/sdks/go/pkg/beam/util/starcgenx/starcgenx.go
@@ -0,0 +1,562 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package starcgenx is a Static Analysis Type Assertion shim and Registration Code Generator
+// which provides an extractor to extract types from a package, in order to generate
+// appropriate shims for a package so code can be generated for it.
+//
+// It's written for use by the starcgen tool, but separate to permit
+// alternative "go/importer" Importers for accessing types from imported packages.
+package starcgenx
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "go/types"
+ "strings"
+
+ "github.com/apache/beam/sdks/go/pkg/beam/util/shimx"
+)
+
+// NewExtractor creates an Extractor for the package pkg with every
+// bookkeeping map allocated and allExported initially set to true.
+func NewExtractor(pkg string) *Extractor {
+	e := &Extractor{Package: pkg, allExported: true}
+	e.functions = map[string]struct{}{}
+	e.types = map[string]struct{}{}
+	e.funcs = map[string]*types.Signature{}
+	e.emits = map[string]shimx.Emitter{}
+	e.iters = map[string]shimx.Input{}
+	e.imports = map[string]struct{}{}
+	return e
+}
+
+// Extractor contains and uniquifies the cache of types and things that need to be generated.
+type Extractor struct {
+ // w is the buffer that Print and Printf write to, and that Generate
+ // hands to shimx.File and returns the contents of.
+ w bytes.Buffer
+ // Package identifies the package being analysed. It is passed to the type
+ // checker and compared against package names when qualifying identifiers.
+ Package string
+ // debug gates Print and Printf: they write nothing unless it is set.
+ debug bool
+
+ // Ids is an optional slice of package local identifiers. When non-empty,
+ // FromAsts only keeps objects selected by these identifiers and returns
+ // an error listing any identifier it could not find.
+ Ids []string
+
+ // Register and uniquify the needed shims for each kind.
+ // Functions to Register
+ functions map[string]struct{}
+ // Types to Register (structs, essentially)
+ types map[string]struct{}
+ // FuncShims needed
+ funcs map[string]*types.Signature
+ // Emitter Shims needed
+ emits map[string]shimx.Emitter
+ // Iterator Shims needed
+ iters map[string]shimx.Input
+
+ // list of packages we need to import.
+ imports map[string]struct{}
+
+ // Marks if all ptransforms are exported and available in main.
+ // It is set to true by NewExtractor and reported by Summary.
+ allExported bool
+}
+
+// Summary writes a short report to the extractor buffer: whether everything
+// is exported, followed by the number of functions, types, shims, emitters
+// and inputs collected so far. Like all Print output, it only appears when
+// debug is set.
+func (e *Extractor) Summary() {
+	e.Print("\n")
+	e.Print("Summary\n")
+	e.Printf("All exported?: %v\n", e.allExported)
+
+	rows := []struct {
+		format string
+		count  int
+	}{
+		{"%d\t Functions\n", len(e.functions)},
+		{"%d\t Types\n", len(e.types)},
+		{"%d\t Shims\n", len(e.funcs)},
+		{"%d\t Emits\n", len(e.emits)},
+		{"%d\t Inputs\n", len(e.iters)},
+	}
+	for _, row := range rows {
+		e.Printf(row.format, row.count)
+	}
+}
+
+// lifecycleMethodName reports whether n is one of the method names the Go SDK
+// uses as DoFn or CombineFn lifecycle methods. These are the only methods
+// that need shims generated for them, as per beam/core/graph/fn.go
+// TODO(lostluck): Move this to beam/core/graph/fn.go, so it can stay up to date.
+func lifecycleMethodName(n string) bool {
+	switch n {
+	case "ProcessElement",
+		"StartBundle",
+		"FinishBundle",
+		"Setup",
+		"Teardown",
+		"CreateAccumulator",
+		"AddInput",
+		"MergeAccumulators",
+		"ExtractOutput",
+		"Compact":
+		return true
+	}
+	return false
+}
+
+// Bytes returns the current contents of the extractor buffer.
+func (e *Extractor) Bytes() []byte {
+ return e.w.Bytes()
+}
+
+// Print writes s to the extractor buffer via fmt.Fprint.
+// It is a no-op unless debug is set.
+func (e *Extractor) Print(s string) {
+	if !e.debug {
+		return
+	}
+	fmt.Fprint(&e.w, s)
+}
+
+// Printf formats f with args into the extractor buffer via fmt.Fprintf.
+// It is a no-op unless debug is set.
+func (e *Extractor) Printf(f string, args ...interface{}) {
+	if !e.debug {
+		return
+	}
+	fmt.Fprintf(&e.w, f, args...)
+}
+
+// FromAsts type checks files as the package e.Package, using imp to resolve
+// imports, and feeds every definition (and, when e.Ids is set, every use)
+// through fromObj to collect what needs to be generated.
+//
+// It returns an error if type checking fails, or if any identifier listed in
+// e.Ids was not found.
+func (e *Extractor) FromAsts(imp types.Importer, fset *token.FileSet, files []*ast.File) error {
+	filtering := len(e.Ids) > 0
+
+	info := &types.Info{
+		Defs: make(map[*ast.Ident]types.Object),
+	}
+	conf := types.Config{
+		Importer:                 imp,
+		IgnoreFuncBodies:         true,
+		DisableUnusedImportCheck: true,
+	}
+	if filtering {
+		// TODO(lostluck): This becomes unnecessary iff we can figure out
+		// which ParDos are being passed to beam.ParDo or beam.Combine.
+		// If there are ids, we need to also look at function bodies, and uses.
+		// Bodies are only checked when some id contains a ".".
+		dotted := false
+		for _, id := range e.Ids {
+			if strings.Contains(id, ".") {
+				dotted = true
+				break
+			}
+		}
+		conf.IgnoreFuncBodies = !dotted
+		info.Uses = make(map[*ast.Ident]types.Object)
+	}
+
+	if _, err := conf.Check(e.Package, fset, files, info); err != nil {
+		return fmt.Errorf("failed to type check package %s : %v", e.Package, err)
+	}
+
+	e.Print("/*\n")
+	var idsRequired, idsFound map[string]bool
+	if filtering {
+		e.Printf("Filtering by %d identifiers: %q\n", len(e.Ids), strings.Join(e.Ids, ", "))
+		idsRequired = make(map[string]bool)
+		idsFound = make(map[string]bool)
+		for _, id := range e.Ids {
+			idsRequired[id] = true
+		}
+	}
+
+	// Definitions are always scanned; info.Uses is nil (and so skipped)
+	// unless filtering by identifier.
+	passes := []struct {
+		banner  string
+		objects map[*ast.Ident]types.Object
+	}{
+		{"CHECKING DEFS\n", info.Defs},
+		{"CHECKING USES\n", info.Uses},
+	}
+	for _, pass := range passes {
+		e.Print(pass.banner)
+		for id, obj := range pass.objects {
+			e.fromObj(fset, id, obj, idsRequired, idsFound)
+		}
+	}
+
+	var missing []string
+	for _, id := range e.Ids {
+		if idsFound[id] {
+			continue
+		}
+		missing = append(missing, id)
+	}
+	if len(missing) > 0 {
+		return fmt.Errorf("couldn't find the following identifiers; please check for typos, or remove them: %v", strings.Join(missing, ", "))
+	}
+	e.Print("*/\n")
+
+	return nil
+}
+
+func (e *Extractor) isRequired(ident string, obj types.Object, idsRequired,
idsFound map[string]bool) bool {
+ if idsRequired == nil {
+ return true
+ }
+ if idsFound == nil {
+ panic("broken invariant: idsFound map is nil, but idsRequired
map exists")
+ }
+ // If we're filtering IDs, then it needs to be in the filtered
identifiers,
+ // or it's receiver type identifier needs to be in the filtered
identifiers.
+ if idsRequired[ident] {
+ idsFound[ident] = true
+ return true
+ }
+ // Check if this is a function.
+ sig, ok := obj.Type().(*types.Signature)
+ if !ok {
+ return false
+ }
+ // If this is a function, and it has a receiver, it's a method.
+ if recv := sig.Recv(); recv != nil && lifecycleMethodName(ident) {
+ // We don't want to care about pointers, so dereference to
value type.
+ t := recv.Type()
+ p, ok := t.(*types.Pointer)
+ for ok {
+ t = p.Elem()
+ p, ok = t.(*types.Pointer)
+ }
+ ts := types.TypeString(t, e.qualifier)
+ e.Printf("RRR has %v, ts: %s %s--- ", sig, ts, ident)
+ if !idsRequired[ts] {
+ e.Print("IGNORE\n")
+ return false
+ }
+ e.Print("KEEP\n")
+ idsFound[ts] = true
+ return true
+ }
+ return false
+}
+
+// fromObj inspects a single identifier/object pair produced by the type
+// checker and records what must be generated for it:
+//   - functions are registered (except init) and get a shim for their signature,
+//   - lifecycle methods get a shim only,
+//   - type names are handed to extractType,
+//   - vars and anything else are ignored (apart from debug output).
+//
+// Objects without a package, and objects rejected by isRequired, are skipped.
+func (e *Extractor) fromObj(fset *token.FileSet, id *ast.Ident, obj types.Object, idsRequired, idsFound map[string]bool) {
+	where := func() token.Position { return fset.Position(id.Pos()) }
+
+	if obj == nil { // Omit the package declaration.
+		e.Printf("%s: %q has no object, probably a package\n", where(), id.Name)
+		return
+	}
+	pkg := obj.Pkg()
+	if pkg == nil {
+		// No meaningful identifier.
+		e.Printf("%s: %q has no package \n", where(), id.Name)
+		return
+	}
+
+	// Identifiers from the analysed package stay unqualified.
+	local := pkg.Name() == e.Package
+	ident := obj.Name()
+	if !local {
+		ident = fmt.Sprintf("%s.%s", pkg.Name(), obj.Name())
+	}
+	if !e.isRequired(ident, obj, idsRequired, idsFound) {
+		return
+	}
+
+	switch ot := obj.(type) {
+	case *types.Var:
+		// Vars are tricky since they could be anything, and anywhere (package scope, parameters, etc)
+		// eg. Flags, or Field Tags, among others.
+		// I'm increasingly convinced that we should simply ignore vars.
+		// Do nothing for vars.
+	case *types.TypeName:
+		e.Printf("%s: %q is a type %v --- %T %v %v %v %v\n",
+			where(), id.Name, obj, obj, id, pkg, obj.Type(), obj.Name())
+		// Probably need to sanity check that this type actually is/has a ProcessElement
+		// or MergeAccumulators defined for this type so unnecessary registrations don't happen,
+		// and can explicitly produce an error if an explicitly named type *isn't* a DoFn or CombineFn.
+		e.extractType(ot)
+	case *types.Func:
+		sig := obj.Type().(*types.Signature)
+		switch recv := sig.Recv(); {
+		case recv != nil:
+			// Methods don't need registering, but they do need shim generation.
+			e.Printf("%s: %q is a method of %v -> %v--- %T %v %v %v\n",
+				where(), id.Name, recv.Type(), obj, obj, id, pkg, obj.Type())
+			if !lifecycleMethodName(id.Name) {
+				// If this is not a lifecycle method, we should ignore it.
+				return
+			}
+		case id.Name != "init":
+			// init functions are special and should be ignored.
+			// Functions need registering, as well as shim generation.
+			e.Printf("%s: %q is a top level func %v --- %T %v %v %v\n",
+				where(), ident, obj, obj, id, pkg, obj.Type())
+			e.functions[ident] = struct{}{}
+		}
+		// For functions from other packages.
+		if !local {
+			e.imports[pkg.Path()] = struct{}{}
+		}
+
+		e.funcs[e.sigKey(sig)] = sig
+		e.extractParameters(sig)
+		e.Printf("\t%v\n", sig)
+	default:
+		e.Printf("%s: %q defines %v --- %T %v %v %v\n",
+			where(), types.ObjectString(obj, e.qualifier), obj, obj, id, pkg, obj.Type())
+	}
+}
+
+// extractType records the type named by ot as one that needs registering,
+// keyed by its type string as qualified for the package being analysed.
+//
+// For an alias whose target is a named type, the package declaring the
+// target is additionally recorded as a required import.
+func (e *Extractor) extractType(ot *types.TypeName) {
+	name := types.TypeString(ot.Type(), e.qualifier)
+	// Unwrap an alias by one level.
+	// Attempting to dereference a full chain of aliases runs the risk of crossing
+	// a visibility boundary such as internal packages.
+	// A single level is safe since the code we're analysing imports it,
+	// so we can assume the generated code can access it too.
+	if ot.IsAlias() {
+		// name already describes the alias target, since it was computed from
+		// the same ot.Type() value; only the import needs recording. (The
+		// previous reassignments of ot and name here had no effect.)
+		if t, ok := ot.Type().(*types.Named); ok {
+			if pkg := t.Obj().Pkg(); pkg != nil {
+				e.imports[pkg.Path()] = struct{}{}
+			}
+		}
+	}
+	e.types[name] = struct{}{}
+}
+
+// Examines the signature and extracts types of parameters for generating
+// necessary imports and emitter and iterator code.
+func (e *Extractor) extractParameters(sig *types.Signature) {
+ in := sig.Params() // *types.Tuple
+ for i := 0; i < in.Len(); i++ {
+ s := in.At(i) // *types.Var
+
+ // Pointer types need to be iteratively unwrapped until we're
at the base type,
+ // so we can get the import if necessary.
+ t := s.Type()
+ p, ok := t.(*types.Pointer)
+ for ok {
+ t = p.Elem()
+ p, ok = t.(*types.Pointer)
+ }
+ // Here's were we ensure we register new imports.
+ if t, ok := t.(*types.Named); ok {
+ if pkg := t.Obj().Pkg(); pkg != nil {
+ e.imports[pkg.Path()] = struct{}{}
+ }
+ e.extractType(t.Obj())
+ }
+
+ if a, ok := s.Type().(*types.Signature); ok {
+ // Check if the type is an emitter or iterator for the
specialized
+ // shim generation for those types.
+ if emt, ok := e.makeEmitter(a); ok {
+ e.emits[emt.Name] = emt
+ }
+ if ipt, ok := e.makeInput(a); ok {
+ e.iters[ipt.Name] = ipt
+ }
+ // Tail recurse on functional parameters.
+ e.extractParameters(a)
+ }
+ }
+}
+
+// qualifier is the types.Qualifier used when rendering types: it yields the
+// empty string for the package being analysed, so its identifiers print
+// unqualified, and the package name otherwise.
+func (e *Extractor) qualifier(pkg *types.Package) string {
+	if name := tail(pkg.Name()); name != e.Package {
+		return name
+	}
+	return ""
+}
+
+// tail returns the portion of path after its final "/", or path itself when
+// it contains no "/".
+//
+// The arguments to strings.LastIndex were previously swapped (searching for
+// path inside "/"), which made this function return its input unchanged, and
+// the slice would have kept the separator.
+func tail(path string) string {
+	if i := strings.LastIndex(path, "/"); i >= 0 {
+		return path[i+1:]
+	}
+	return path
+}
+
+// tupleStrings returns the type string of every variable in t, in order,
+// qualified relative to the package being analysed. The result is nil for an
+// empty tuple.
+func (e *Extractor) tupleStrings(t *types.Tuple) []string {
+	var out []string
+	for i, n := 0, t.Len(); i < n; i++ {
+		typ := t.At(i).Type()
+		out = append(out, types.TypeString(typ, e.qualifier))
+	}
+	return out
+}
+
+// sigKey builds a key for sig from its parameter and result types only, so
+// that signatures differing just in variable names share a key.
+func (e *Extractor) sigKey(sig *types.Signature) string {
+	params := strings.Join(e.tupleStrings(sig.Params()), ",")
+	results := strings.Join(e.tupleStrings(sig.Results()), ",")
+	return fmt.Sprintf("func(%v) (%v)", params, results)
+}
+
+// Generate produces an additional file for the Go package that was extracted,
+// to be included in a subsequent compilation.
+func (e *Extractor) Generate(filename string) []byte {
+ var functions []string
+ for fn := range e.functions {
+ // No extra processing necessary, since these should all be
package local.
+ functions = append(functions, fn)
+ }
+ var typs []string
+ for t := range e.types {
+ typs = append(typs, t)
+ }
+ var shims []shimx.Func
+ for sig, t := range e.funcs {
+ shim := shimx.Func{Type: sig}
+ var inNames []string
+ in := t.Params() // *types.Tuple
+ for i := 0; i < in.Len(); i++ {
+ s := in.At(i) // *types.Var
+ shim.In = append(shim.In, types.TypeString(s.Type(),
e.qualifier))
+ inNames = append(inNames, e.NameType(s.Type()))
+ }
+ var outNames []string
+ out := t.Results() // *types.Tuple
+ for i := 0; i < out.Len(); i++ {
+ s := out.At(i)
+ shim.Out = append(shim.Out, types.TypeString(s.Type(),
e.qualifier))
+ outNames = append(outNames, e.NameType(s.Type()))
+ }
+ shim.Name = shimx.FuncName(inNames, outNames)
+ shims = append(shims, shim)
+ }
+ var emits []shimx.Emitter
+ for _, t := range e.emits {
+ emits = append(emits, t)
+ }
+ var inputs []shimx.Input
+ for _, t := range e.iters {
+ inputs = append(inputs, t)
+ }
+
+ var imports []string
+ for k := range e.imports {
+ if k == "" || k == e.Package {
+ continue
+ }
+ imports = append(imports, k)
+ }
+
+ top := shimx.Top{
+ FileName: filename,
+ ToolName: "starcgen",
+ Package: e.Package,
+ Imports: imports,
+ Functions: functions,
+ Types: typs,
+ Shims: shims,
+ Emitters: emits,
+ Inputs: inputs,
+ }
+ e.Print("\n")
+ shimx.File(&e.w, &top)
+ return e.w.Bytes()
+}
+
+// makeEmitter checks whether sig has the shape of an emitter (no results and
+// one to three parameters) and, if so, returns the shimx.Emitter describing
+// it for the code generator. The second result is false otherwise.
+func (e *Extractor) makeEmitter(sig *types.Signature) (shimx.Emitter, bool) {
+	// Emitters must have no return values.
+	if sig.Results().Len() != 0 {
+		return shimx.Emitter{}, false
+	}
+	params := sig.Params()
+	emt := shimx.Emitter{Type: e.sigKey(sig)}
+	switch params.Len() {
+	case 1:
+		emt.Time = false
+		emt.Val = e.varString(params.At(0))
+	case 2:
+		// TODO(rebo): Fix this when imports are resolved.
+		// This is the tricky one... Need to verify what happens with aliases.
+		// And get a handle to compare this against somewhere. isEventTime(p.At(0)) maybe.
+		// if p.At(0) == typex.EventTimeType {
+		//   emt.Time = true
+		// } else {
+		emt.Key = e.varString(params.At(0))
+		//}
+		emt.Val = e.varString(params.At(1))
+	case 3:
+		// If there's 3, the first one must be typex.EventTime.
+		emt.Time = true
+		emt.Key = e.varString(params.At(1))
+		emt.Val = e.varString(params.At(2))
+	default:
+		return shimx.Emitter{}, false
+	}
+
+	// Event-time emitters get an "ET" prefix in their name.
+	format := "%s%s"
+	if emt.Time {
+		format = "ET%s%s"
+	}
+	emt.Name = fmt.Sprintf(format, shimx.Name(emt.Key), shimx.Name(emt.Val))
+	return emt, true
+}
+
+// makeInput checks if the given signature is an iterator or not, and if so,
+// returns a shimx.Input struct for the signature for use by the code
+// generator. The canonical check for an iterator signature is in the
+// funcx.UnfoldIter function which uses the reflect library,
+// and this logic is replicated here: a single bool result, and one to three
+// parameters that are all pointers.
+func (e *Extractor) makeInput(sig *types.Signature) (shimx.Input, bool) {
+	results := sig.Results()
+	if results.Len() != 1 {
+		return shimx.Input{}, false
+	}
+	// Iterators must return a bool.
+	basic, isBasic := results.At(0).Type().(*types.Basic)
+	if !isBasic || basic.Kind() != types.Bool {
+		return shimx.Input{}, false
+	}
+	// All params for iterators must be pointers.
+	params := sig.Params()
+	for i, n := 0, params.Len(); i < n; i++ {
+		if _, isPtr := params.At(i).Type().(*types.Pointer); !isPtr {
+			return shimx.Input{}, false
+		}
+	}
+
+	itr := shimx.Input{Type: e.sigKey(sig)}
+	switch params.Len() {
+	case 1:
+		itr.Time = false
+		itr.Val = e.deref(params.At(0))
+	case 2:
+		// TODO(rebo): Fix this when imports are resolved.
+		// This is the tricky one... Need to verify what happens with aliases.
+		// And get a handle to compare this against somewhere. isEventTime(p.At(0)) maybe.
+		// if p.At(0) == typex.EventTimeType {
+		//   itr.Time = true
+		// } else {
+		itr.Key = e.deref(params.At(0))
+		//}
+		itr.Val = e.deref(params.At(1))
+	case 3:
+		// If there's 3, the first one must be typex.EventTime.
+		itr.Time = true
+		itr.Key = e.deref(params.At(1))
+		itr.Val = e.deref(params.At(2))
+	default:
+		return shimx.Input{}, false
+	}
+
+	// Event-time iterators get an "ET" prefix in their name.
+	format := "%s%s"
+	if itr.Time {
+		format = "ET%s%s"
+	}
+	itr.Name = fmt.Sprintf(format, shimx.Name(itr.Key), shimx.Name(itr.Val))
+	return itr, true
+}
+
+// deref renders the element type that the pointer-typed var v points to,
+// qualified relative to the package being analysed.
+// It panics if v's type is not a pointer.
+func (e *Extractor) deref(v *types.Var) string {
+	elem := v.Type().(*types.Pointer).Elem()
+	return types.TypeString(elem, e.qualifier)
+}
+
+// varString renders the type of v as it should be written inside the
+// package for which code is being generated.
+func (e *Extractor) varString(v *types.Var) string {
+	typ := v.Type()
+	return types.TypeString(typ, e.qualifier)
+}
+
+// NameType turns a types.Type into a string based on its name.
+// Pointers are named after their element type, slices get a "Sliceof"
+// prefix, and function types get an "Emit" or "Iter" prefix if they satisfy
+// the constraints of emitters or iterators respectively.
+func (e *Extractor) NameType(t types.Type) string {
+	switch typ := t.(type) {
+	case *types.Pointer:
+		return e.NameType(typ.Elem())
+	case *types.Slice:
+		return "Sliceof" + e.NameType(typ.Elem())
+	case *types.Signature:
+		if emt, ok := e.makeEmitter(typ); ok {
+			return "Emit" + emt.Name
+		}
+		if ipt, ok := e.makeInput(typ); ok {
+			return "Iter" + ipt.Name
+		}
+		return shimx.Name(e.sigKey(typ))
+	}
+	return shimx.Name(types.TypeString(t, e.qualifier))
+}
diff --git a/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go
b/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go
new file mode 100644
index 00000000000..9141acb114e
--- /dev/null
+++ b/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package starcgenx
+
+import (
+ "go/ast"
+ "go/importer"
+ "go/parser"
+ "go/token"
+ "strings"
+ "testing"
+)
+
+// TestExtractor parses each fixture source, runs it through an Extractor
+// (optionally filtered by ids), and checks that the generated file contains
+// every expected snippet and none of the excluded ones.
+func TestExtractor(t *testing.T) {
+	cases := []struct {
+		name     string
+		pkg      string
+		files    []string
+		ids      []string
+		expected []string
+		excluded []string
+	}{
+		{
+			name:  "pardo1",
+			files: []string{pardo},
+			pkg:   "pardo",
+			expected: []string{
+				"runtime.RegisterFunction(MyIdent)",
+				"runtime.RegisterFunction(MyDropVal)",
+				"runtime.RegisterFunction(MyOtherDoFn)",
+				"runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())",
+				"funcMakerStringГString",
+				"funcMakerIntStringГInt",
+				"funcMakerFooГStringFoo",
+			},
+		},
+		{
+			name:  "emits1",
+			files: []string{emits},
+			pkg:   "emits",
+			expected: []string{
+				"runtime.RegisterFunction(anotherFn)",
+				"runtime.RegisterFunction(emitFn)",
+				"runtime.RegisterType(reflect.TypeOf((*reInt)(nil)).Elem())",
+				"funcMakerEmitIntIntГ",
+				"emitMakerIntInt",
+				"funcMakerIntIntEmitIntIntГError",
+			},
+		},
+		{
+			name:  "iters1",
+			files: []string{iters},
+			pkg:   "iters",
+			expected: []string{
+				"runtime.RegisterFunction(iterFn)",
+				"funcMakerStringIterIntГ",
+				"iterMakerInt",
+			},
+		},
+		{
+			name:  "structs1",
+			files: []string{structs},
+			pkg:   "structs",
+			ids:   []string{"myDoFn"},
+			expected: []string{
+				"runtime.RegisterType(reflect.TypeOf((*myDoFn)(nil)).Elem())",
+				"funcMakerEmitIntГ",
+				"emitMakerInt",
+				"funcMakerValTypeValTypeEmitIntГ",
+				"runtime.RegisterType(reflect.TypeOf((*valType)(nil)).Elem())",
+			},
+			excluded: []string{
+				"funcMakerStringГ",
+				"emitMakerString",
+				"nonPipelineType",
+			},
+		},
+	}
+	for _, tc := range cases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			fset := token.NewFileSet()
+			var parsed []*ast.File
+			for i, src := range tc.files {
+				file, err := parser.ParseFile(fset, "", src, 0)
+				if err != nil {
+					t.Fatalf("couldn't parse test.files[%d]: %v", i, err)
+				}
+				parsed = append(parsed, file)
+			}
+
+			e := NewExtractor(tc.pkg)
+			e.Ids = tc.ids
+			if err := e.FromAsts(importer.Default(), fset, parsed); err != nil {
+				t.Fatal(err)
+			}
+
+			generated := string(e.Generate("test_shims.go"))
+			for _, want := range tc.expected {
+				if !strings.Contains(generated, want) {
+					t.Errorf("expected %q in generated file", want)
+				}
+			}
+			for _, unwanted := range tc.excluded {
+				if strings.Contains(generated, unwanted) {
+					t.Errorf("found %q in generated file", unwanted)
+				}
+			}
+			t.Log(generated)
+		})
+	}
+}
+
+// pardo is the fixture source for the "pardo1" case of TestExtractor: a
+// package with three top-level functions and one user defined struct type.
+const pardo = `
+package pardo
+
+func MyIdent(v string) string {
+ return v
+}
+
+func MyDropVal(k int,v string) int {
+ return k
+}
+
+// A user defined type
+type foo struct{}
+
+func MyOtherDoFn(v foo) (string,foo) {
+ return "constant"
+}
+`
+
+// emits is the fixture source for the "emits1" case of TestExtractor: a
+// package with a named int type and two functions taking a func(int,int)
+// parameter.
+const emits = `
+package emits
+
+type reInt int
+
+func anotherFn(emit func(int,int)) {
+ emit(0, 0)
+}
+
+func emitFn(k,v int, emit func(int,int)) error {
+ for i := 0; i < v; i++ { emit(k, i) }
+ return nil
+}
+`
+// iters is the fixture source for the "iters1" case of TestExtractor: a
+// package with one function taking a func(*int) bool parameter.
+const iters = `
+package iters
+
+func iterFn(k string, iters func(*int) bool) {}
+`
+
+// structs is the fixture source for the "structs1" case of TestExtractor,
+// which filters by the identifier "myDoFn": a struct type with lifecycle
+// methods, plus unrelated methods and a type that the test expects to be
+// left out of the generated file.
+const structs = `
+package structs
+
+type myDoFn struct{}
+
+// valType should be picked up via processElement
+type valType int
+
+func (f *myDoFn) ProcessElement(k, v valType, emit func(int)) {}
+
+func (f *myDoFn) Setup(emit func(int)) {}
+func (f *myDoFn) StartBundle(emit func(int)) {}
+func (f *myDoFn) FinishBundle(emit func(int)) error {}
+func (f *myDoFn) Teardown(emit func(int)) {}
+
+type nonPipelineType int
+
+// UnrelatedMethods shouldn't have shims or tangents generated for them
+func (f *myDoFn) UnrelatedMethod1(v string) {}
+func (f *myDoFn) UnrelatedMethod2(notEmit func(string)) {}
+
+func (f *myDoFn) UnrelatedMethod3(notEmit func(nonPipelineType)) {}
+`
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 165044)
Time Spent: 6h 20m (was: 6h 10m)
> Make it easy to generate type-specialized Go SDK reflectx.Funcs
> ---------------------------------------------------------------
>
> Key: BEAM-3612
> URL: https://issues.apache.org/jira/browse/BEAM-3612
> Project: Beam
> Issue Type: Improvement
> Components: sdk-go
> Reporter: Henning Rohde
> Assignee: Robert Burke
> Priority: Major
> Time Spent: 6h 20m
> Remaining Estimate: 0h
>
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)