This is an automated email from the ASF dual-hosted git repository. altay pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new cf06c4d [BEAM-3612] Add a shim generator tool (#7000) cf06c4d is described below commit cf06c4d7ba8078e933798c9656e3e98566222b1c Author: Robert Burke <lostl...@users.noreply.github.com> AuthorDate: Mon Nov 12 10:03:10 2018 -0800 [BEAM-3612] Add a shim generator tool (#7000) * [BEAM-3612] Add a shim generator tool --- sdks/go/cmd/starcgen/starcgen.go | 154 ++++++ sdks/go/cmd/starcgen/starcgen_test.go | 123 +++++ sdks/go/pkg/beam/util/shimx/generate.go | 413 ++++++++++++++++ sdks/go/pkg/beam/util/shimx/generate_test.go | 217 +++++++++ sdks/go/pkg/beam/util/starcgenx/starcgenx.go | 562 ++++++++++++++++++++++ sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go | 145 ++++++ 6 files changed, 1614 insertions(+) diff --git a/sdks/go/cmd/starcgen/starcgen.go b/sdks/go/cmd/starcgen/starcgen.go new file mode 100644 index 0000000..87e8011 --- /dev/null +++ b/sdks/go/cmd/starcgen/starcgen.go @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// starcgen is a tool to generate specialized type assertion shims to be +// used in Apache Beam Go SDK pipelines instead of the default reflection shim. 
+// This is done through static analysis of go sources for the package in question. +package main + +import ( + "flag" + "fmt" + "go/ast" + "go/importer" + "go/parser" + "go/token" + "io" + "log" + "os" + "path/filepath" + "strings" + + "github.com/apache/beam/sdks/go/pkg/beam/util/starcgenx" +) + +var ( + inputs = flag.String("inputs", "", "comma separated list of file with types to create") + output = flag.String("output", "", "output file with types to create") + ids = flag.String("identifiers", "", "comma separated list of package local identifiers for which to generate code") +) + +// Generate takes the typechecked inputs, and generates the shim file for the relevant +// identifiers. +func Generate(w io.Writer, filename, pkg string, ids []string, fset *token.FileSet, files []*ast.File) error { + e := starcgenx.NewExtractor(pkg) + e.Ids = ids + + // Importing from source should work in most cases. + imp := importer.For("source", nil) + if err := e.FromAsts(imp, fset, files); err != nil { + // Always print out the debugging info to the file. 
+ if _, errw := w.Write(e.Bytes()); errw != nil { + return fmt.Errorf("error writing debug data to file after err %v:%v", err, errw) + } + return fmt.Errorf("error extracting from asts: %v", err) + } + + e.Print("*/\n") + data := e.Generate(filename) + if err := write(w, []byte(license)); err != nil { + return err + } + return write(w, data) +} + +func write(w io.Writer, data []byte) error { + n, err := w.Write(data) + if err != nil && n < len(data) { + return fmt.Errorf("short write of data got %d, want %d", n, len(data)) + } + return err +} + +func usage() { + fmt.Fprintf(os.Stderr, "Usage: %v [options] --inputs=<comma separated of go files>\n", filepath.Base(os.Args[0])) + flag.PrintDefaults() +} + +func main() { + flag.Usage = usage + flag.Parse() + + log.SetFlags(log.Lshortfile) + log.SetPrefix("starcgen: ") + + ipts := strings.Split(*inputs, ",") + fset := token.NewFileSet() + var fs []*ast.File + var pkg string + + dir, err := filepath.Abs(filepath.Dir(os.Args[0])) + if err != nil { + log.Fatal(err) + } + + for _, i := range ipts { + f, err := parser.ParseFile(fset, i, nil, 0) + if err != nil { + err1 := err + f, err = parser.ParseFile(fset, filepath.Join(dir, i), nil, 0) + if err != nil { + log.Print(err1) + log.Fatal(err) // parse error + } + } + + if pkg == "" { + pkg = f.Name.Name + } else if pkg != f.Name.Name { + log.Fatalf("Input file %v has mismatched package path, got %q, want %q", i, f.Name.Name, pkg) + } + fs = append(fs, f) + } + if pkg == "" { + log.Fatalf("No package detected in input files: %v", inputs) + } + + if *output == "" { + name := pkg + if len(ipts) == 1 { + name = filepath.Base(ipts[0]) + if index := strings.Index(name, "."); index > 0 { + name = name[:index] + } + } + *output = filepath.Join(filepath.Dir(ipts[0]), name+".shims.go") + } + + f, err := os.OpenFile(*output, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + log.Fatal(err) + } + if err := Generate(f, *output, pkg, strings.Split(*ids, ","), fset, fs); err != 
nil { + log.Fatal(err) + } +} + +const license = `// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +` diff --git a/sdks/go/cmd/starcgen/starcgen_test.go b/sdks/go/cmd/starcgen/starcgen_test.go new file mode 100644 index 0000000..7282ada --- /dev/null +++ b/sdks/go/cmd/starcgen/starcgen_test.go @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package main + +import ( + "bytes" + "go/ast" + "go/parser" + "go/token" + "strings" + "testing" +) + +// TestGenerate parses the in-memory sources of each case, runs Generate into a +// buffer, and checks which substrings are present in or absent from the output. +func TestGenerate(t *testing.T) { + tests := []struct { + name string + pkg string + files []string + ids []string + expected []string + excluded []string + }{ + {name: "genAllSingleFile", files: []string{hello1}, pkg: "hello", ids: []string{}, + expected: []string{"runtime.RegisterFunction(MyTitle)", "runtime.RegisterFunction(MyOtherDoFn)", "runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerContext۰ContextStringГString", "funcMakerFooГString"}, + }, + {name: "genSpecificSingleFile", files: []string{hello1}, pkg: "hello", ids: []string{"MyTitle"}, + expected: []string{"runtime.RegisterFunction(MyTitle)", "funcMakerContext۰ContextStringГString"}, + excluded: []string{"MyOtherDoFn", "runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerFooГString"}, + }, + {name: "genAllMultiFile", files: []string{hello1, hello2}, pkg: "hello", ids: []string{}, + expected: []string{"runtime.RegisterFunction(MyTitle)", "runtime.RegisterFunction(MyOtherDoFn)", "runtime.RegisterFunction(anotherFn)", "runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerContext۰ContextStringГString", "funcMakerFooГString", "funcMakerShimx۰EmitterГString", "funcMakerShimx۰EmitterГFoo"}, + }, + {name: "genSpecificMultiFile1", files: []string{hello1, hello2}, pkg: "hello", ids: []string{"MyTitle"}, + expected: []string{"runtime.RegisterFunction(MyTitle)", "funcMakerContext۰ContextStringГString"}, + excluded: []string{"MyOtherDoFn", "anotherFn", "runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerFooГString", "funcMakerShimx۰EmitterГString", "funcMakerShimx۰EmitterГFoo"}, + }, + // NOTE(review): the expected list of the next case repeats the same string twice; + // one entry was presumably meant to be a different check - confirm. + {name: "genSpecificMultiFile2", files: []string{hello1, hello2}, pkg: "hello", ids: []string{"anotherFn"}, + expected: []string{"funcMakerShimx۰EmitterГString", "funcMakerShimx۰EmitterГString"}, + excluded: []string{"MyOtherDoFn", "MyTitle", 
"runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerFooГString"}, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + fset := token.NewFileSet() + var fs []*ast.File + for i, f := range test.files { + n, err := parser.ParseFile(fset, "", f, 0) + if err != nil { + t.Fatalf("couldn't parse test.files[%d]: %v", i, err) + } + fs = append(fs, n) + } + var b bytes.Buffer + if err := Generate(&b, test.name+".go", test.pkg, test.ids, fset, fs); err != nil { + t.Fatal(err) + } + s := string(b.Bytes()) + for _, i := range test.expected { + if !strings.Contains(s, i) { + t.Errorf("expected %q in generated file", i) + } + } + for _, i := range test.excluded { + if strings.Contains(s, i) { + t.Errorf("found %q in generated file", i) + } + } + t.Log(s) + }) + } +} + +// hello1 is test input source declaring MyTitle, foo and MyOtherDoFn. +const hello1 = ` +package hello + +import ( + "context" + "strings" +) + +func MyTitle(ctx context.Context, v string) string { + return strings.Title(v) +} + +type foo struct{} + +func MyOtherDoFn(v foo) string { + return "constant" +} +` + +// hello2 is test input source declaring anotherFn and fooFn, both taking a shimx.Emitter. +const hello2 = ` +package hello + +import ( + "context" + "strings" + + "github.com/apache/beam/sdks/go/pkg/beam/util/shimx" +) + +func anotherFn(v shimx.Emitter) string { + return v.Name +} + +func fooFn(v shimx.Emitter) foo { + return foo{} +} +` diff --git a/sdks/go/pkg/beam/util/shimx/generate.go b/sdks/go/pkg/beam/util/shimx/generate.go new file mode 100644 index 0000000..6c0eb4a --- /dev/null +++ b/sdks/go/pkg/beam/util/shimx/generate.go @@ -0,0 +1,413 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package shimx specifies the templates for generating type assertion shims for +// Apache Beam Go SDK pipelines. +// +// In particular, the shims are used by the Beam Go SDK to avoid reflection at runtime, +// which is the default mode of operation. The shims are specialized for the code +// in question, using type assertion to convert arguments as required, and invoke the +// user code. +// +// Similar shims are required for emitters, and iterators in order to propagate values +// out of, and in to user functions respectively without reflection overhead. +// +// Registering user types is required to support user types as PCollection +// types, while registering functions is required to avoid possibly expensive function +// resolution at worker start up, which defaults to using DWARF Symbol tables. +// +// The generator largely relies on basic types and strings to ensure that it's usable +// by dynamic processes via reflection, or by any static analysis approach that is +// used in the future. +package shimx + +import ( + "fmt" + "io" + "sort" + "strings" + "text/template" +) + +// Beam imports that the generated code requires. 
var (
	ExecImport     = "github.com/apache/beam/sdks/go/pkg/beam/core/runtime/exec"
	TypexImport    = "github.com/apache/beam/sdks/go/pkg/beam/core/typex"
	ReflectxImport = "github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx"
	RuntimeImport  = "github.com/apache/beam/sdks/go/pkg/beam/core/runtime"
)

// validateBeamImports panics if any of the Beam import paths above no longer
// ends in the package name the template refers to (exec, typex, reflectx,
// runtime), since the generated code would then not compile.
func validateBeamImports() {
	checkImportSuffix(ExecImport, "exec")
	checkImportSuffix(TypexImport, "typex")
	checkImportSuffix(ReflectxImport, "reflectx")
	checkImportSuffix(RuntimeImport, "runtime")
}

// checkImportSuffix panics unless the last element of the import path is
// exactly suffix. A plain suffix match is not enough: "x/notexec" ends with
// "exec" but would be imported under a different name.
func checkImportSuffix(path, suffix string) {
	if path != suffix && !strings.HasSuffix(path, "/"+suffix) {
		panic(fmt.Sprintf("expected %v to end with %v. can't generate valid code", path, suffix))
	}
}

// Top is the top level inputs into the template file for generating shims.
type Top struct {
	FileName, ToolName, Package string

	Imports   []string // the full import paths
	Functions []string // the plain names of the functions to be registered.
	Types     []string // the plain names of the types to be registered.
	Emitters  []Emitter
	Inputs    []Input
	Shims     []Func
}

// sort orders the shims consistently to minimize diffs in the generated code.
func (t *Top) sort() {
	sort.Strings(t.Imports)
	sort.Strings(t.Functions)
	sort.Strings(t.Types)
	sort.SliceStable(t.Emitters, func(i, j int) bool {
		return t.Emitters[i].Name < t.Emitters[j].Name
	})
	sort.SliceStable(t.Inputs, func(i, j int) bool {
		return t.Inputs[i].Name < t.Inputs[j].Name
	})
	sort.SliceStable(t.Shims, func(i, j int) bool {
		return t.Shims[i].Name < t.Shims[j].Name
	})
}

// processImports removes imports that are otherwise handled by the template
// This method is on the value to shallow copy the Field references to avoid
// mutating the user provided instance.
//
// The returned copy lists the Beam imports the template needs first, followed
// by the remaining user imports in their original order. Imports the template
// writes itself are dropped, and repeated paths are emitted only once because
// importing the same package twice does not compile.
func (t Top) processImports() *Top {
	// pred holds every path that must not be appended (again).
	pred := map[string]bool{"reflect": true}
	var filtered []string
	if len(t.Emitters) > 0 {
		pred["context"] = true
	}
	if len(t.Inputs) > 0 {
		pred["fmt"] = true
		pred["io"] = true
	}
	if len(t.Types) > 0 || len(t.Functions) > 0 {
		filtered = append(filtered, RuntimeImport)
		pred[RuntimeImport] = true
	}
	if len(t.Shims) > 0 {
		filtered = append(filtered, ReflectxImport)
		pred[ReflectxImport] = true
	}
	if len(t.Emitters) > 0 || len(t.Inputs) > 0 {
		filtered = append(filtered, ExecImport)
		pred[ExecImport] = true
	}
	// typex is needed by the emitter shims, and by iterators using event time.
	needTypexImport := len(t.Emitters) > 0
	for _, i := range t.Inputs {
		if i.Time {
			needTypexImport = true
			break
		}
	}
	if needTypexImport {
		filtered = append(filtered, TypexImport)
		pred[TypexImport] = true
	}
	for _, imp := range t.Imports {
		if pred[imp] {
			continue
		}
		pred[imp] = true // drop later duplicates of the same path
		filtered = append(filtered, imp)
	}
	t.Imports = filtered
	return &t
}

// Emitter represents an emitter shim to be generated.
type Emitter struct {
	Name, Type string // The user name of the function, the type of the emit.
	Time       bool   // if this uses event time.
	Key, Val   string // The type of the emits.
}

// Input represents an iterator shim to be generated.
type Input struct {
	Name, Type string // The user name of the function, the type of the iterator (including the bool).
	Time       bool   // if this uses event time.
	Key, Val   string // The type of the inputs, pointers removed.
}

// Func represents a type assertion shim for function invocation to be generated.
type Func struct {
	Name, Type string
	In, Out    []string
}

// Name creates a capitalized identifier from a type string. The identifier
// follows the rules of go identifiers and should be compileable.
// See https://golang.org/ref/spec#Identifiers for details.
+func Name(t string) string { + // A leading "[]" is moved to a "Slice" suffix on the element type. + if strings.HasPrefix(t, "[]") { + return Name(t[2:] + "Slice") + } + + t = strings.Replace(t, "beam.", "typex.", -1) + t = strings.Replace(t, ".", "۰", -1) // For packages + t = strings.Replace(t, "*", "Ꮨ", -1) // For pointers + t = strings.Replace(t, "[", "_", -1) // For maps + t = strings.Replace(t, "]", "_", -1) + return strings.Title(t) +} + +// FuncName returns a compilable Go identifier for a function, given valid +// type names as generated by Name. The input names and the output names are +// each concatenated and joined by a single "Г". +// See https://golang.org/ref/spec#Identifiers for details. +func FuncName(inNames, outNames []string) string { + return fmt.Sprintf("%sГ%s", strings.Join(inNames, ""), strings.Join(outNames, "")) +} + +// File writes go code to the given writer. +// It panics if the Beam import paths fail validation, and works on a filtered, +// sorted copy of top. +// NOTE(review): the error returned by Execute is discarded, so a failed or +// partial write is not reported to the caller. +func File(w io.Writer, top *Top) { + validateBeamImports() + top = top.processImports() + top.sort() + vampireTemplate.Funcs(funcMap).Execute(w, top) +} + +// vampireTemplate is the template File executes to render the generated shim file. +var vampireTemplate = template.Must(template.New("vampire").Funcs(funcMap).Parse(`// Code generated by {{.ToolName}}. DO NOT EDIT. 
+// File: {{.FileName}} + +package {{.Package}} + +import ( + +{{- if .Emitters}} + "context" +{{- end}} +{{- if .Inputs}} + "fmt" + "io" +{{- end}} + "reflect" +{{- if .Imports}} + + // Library imports +{{- end}} +{{- range $import := .Imports}} + "{{$import}}" +{{- end}} +) + +func init() { +{{- range $x := .Functions}} + runtime.RegisterFunction({{$x}}) +{{- end}} +{{- range $x := .Types}} + runtime.RegisterType(reflect.TypeOf((*{{$x}})(nil)).Elem()) +{{- end}} +{{- range $x := .Shims}} + reflectx.RegisterFunc(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(), funcMaker{{$x.Name}}) +{{- end}} +{{- range $x := .Emitters}} + exec.RegisterEmitter(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(), emitMaker{{$x.Name}}) +{{- end}} +{{- range $x := .Inputs}} + exec.RegisterInput(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(), iterMaker{{$x.Name}}) +{{- end}} +} + +{{range $x := .Shims -}} +type caller{{$x.Name}} struct { + fn {{$x.Type}} +} + +func funcMaker{{$x.Name}}(fn interface{}) reflectx.Func { + f := fn.({{$x.Type}}) + return &caller{{$x.Name}}{fn: f} +} + +func (c *caller{{$x.Name}}) Name() string { + return reflectx.FunctionName(c.fn) +} + +func (c *caller{{$x.Name}}) Type() reflect.Type { + return reflect.TypeOf(c.fn) +} + +func (c *caller{{$x.Name}}) Call(args []interface{}) []interface{} { + {{mktuplef (len $x.Out) "out%d"}}{{- if len $x.Out}} := {{end -}}c.fn({{mkparams "args[%d].(%v)" $x.In}}) + return []interface{}{ {{- mktuplef (len $x.Out) "out%d" -}} } +} + +func (c *caller{{$x.Name}}) Call{{len $x.In}}x{{len $x.Out}}({{mkargs (len $x.In) "arg%v" "interface{}"}}) ({{- mktuple (len $x.Out) "interface{}"}}) { + {{if len $x.Out}}return {{end}}c.fn({{mkparams "arg%d.(%v)" $x.In}}) +} + +{{end}} +{{if .Emitters -}} +type emitNative struct { + n exec.ElementProcessor + fn interface{} + + ctx context.Context + ws []typex.Window + et typex.EventTime +} + +func (e *emitNative) Init(ctx context.Context, ws []typex.Window, et typex.EventTime) error { + e.ctx = ctx + e.ws = ws + 
e.et = et + return nil +} + +func (e *emitNative) Value() interface{} { + return e.fn +} + +{{end -}} +{{range $x := .Emitters -}} +func emitMaker{{$x.Name}}(n exec.ElementProcessor) exec.ReusableEmitter { + ret := &emitNative{n: n} + ret.fn = ret.invoke{{.Name}} + return ret +} + +func (e *emitNative) invoke{{$x.Name}}({{if $x.Time -}} t typex.EventTime, {{end}}{{if $x.Key}}key {{$x.Key}}, {{end}}val {{$x.Val}}) { + value := exec.FullValue{Windows: e.ws, Timestamp: {{- if $x.Time}} t{{else}} e.et{{end}}, {{- if $x.Key}} Elm: key, Elm2: val {{else}} Elm: val{{end -}} } + if err := e.n.ProcessElement(e.ctx, value); err != nil { + panic(err) + } +} + +{{end}} +{{- if .Inputs -}} +type iterNative struct { + s exec.ReStream + fn interface{} + + // cur is the "current" stream, if any. + cur exec.Stream +} + +func (v *iterNative) Init() error { + cur, err := v.s.Open() + if err != nil { + return err + } + v.cur = cur + return nil +} + +func (v *iterNative) Value() interface{} { + return v.fn +} + +func convToString(v interface{}) string { + switch v.(type) { + case []byte: + return string(v.([]byte)) + default: + return v.(string) + } +} + +func (v *iterNative) Reset() error { + if err := v.cur.Close(); err != nil { + return err + } + v.cur = nil + return nil +} +{{- end}} +{{- range $x := .Inputs}} +func iterMaker{{$x.Name}}(s exec.ReStream) exec.ReusableInput { + ret := &iterNative{s: s} + ret.fn = ret.read{{$x.Name}} + return ret +} + +func (v *iterNative) read{{$x.Name}}({{if $x.Time -}} et *typex.EventTime, {{end}}{{if $x.Key}}key *{{$x.Key}}, {{end}}value *{{$x.Val}}) bool { + elm, err := v.cur.Read() + if err != nil { + if err == io.EOF { + return false + } + panic(fmt.Sprintf("broken stream: %v", err)) + } + +{{- if $x.Time}} + *et = elm.Timestamp +{{- end}} +{{- if eq $x.Key "string"}} + *key = convToString(elm.Elm) +{{- else if $x.Key}} + *key = elm.Elm.({{$x.Key}}) +{{- end}} +{{- if eq $x.Val "string"}} + *value = convToString(elm.Elm{{- if $x.Key -}} 2 {{- 
end -}}) +{{- else}} + *value = elm.Elm{{- if $x.Key -}} 2 {{- end -}}.({{$x.Val}}) +{{- end}} + return true +} +{{- end}} + +// DO NOT MODIFY: GENERATED CODE +`)) + +// funcMap contains useful functions for use in the template. +var funcMap template.FuncMap = map[string]interface{}{ + "mkargs": mkargs, + "mkparams": mkparams, + "mktuple": mktuple, + "mktuplef": mktuplef, +} + +// mkargs(n, format, type) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)> type". +// If n is 0, it returns the empty string. +func mkargs(n int, format, typ string) string { + if n == 0 { + return "" + } + return fmt.Sprintf("%v %v", mktuplef(n, format), typ) +} + +// mkparams(format, []type) returns "<fmt.Sprintf(format, 0, type[0])>, .., <fmt.Sprintf(format, n-1, type[n-1])>". +// With no types it returns the empty string. +func mkparams(format string, types []string) string { + var ret []string + for i, t := range types { + ret = append(ret, fmt.Sprintf(format, i, t)) + } + return strings.Join(ret, ", ") +} + +// mktuple(n, v) returns "v, v, ..., v". +func mktuple(n int, v string) string { + var ret []string + for i := 0; i < n; i++ { + ret = append(ret, v) + } + return strings.Join(ret, ", ") +} + +// mktuplef(n, format) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)>" +func mktuplef(n int, format string) string { + var ret []string + for i := 0; i < n; i++ { + ret = append(ret, fmt.Sprintf(format, i)) + } + return strings.Join(ret, ", ") +} diff --git a/sdks/go/pkg/beam/util/shimx/generate_test.go b/sdks/go/pkg/beam/util/shimx/generate_test.go new file mode 100644 index 0000000..3696bba --- /dev/null +++ b/sdks/go/pkg/beam/util/shimx/generate_test.go @@ -0,0 +1,217 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shimx + +import ( + "bytes" + "sort" + "testing" +) + +func TestTop_Sort(t *testing.T) { + top := Top{ + Imports: []string{"z", "a", "r"}, + Functions: []string{"z", "a", "r"}, + Types: []string{"z", "a", "r"}, + Emitters: []Emitter{{Name: "z"}, {Name: "a"}, {Name: "r"}}, + Inputs: []Input{{Name: "z"}, {Name: "a"}, {Name: "r"}}, + Shims: []Func{{Name: "z"}, {Name: "a"}, {Name: "r"}}, + } + + top.sort() + if !sort.SliceIsSorted(top.Imports, func(i, j int) bool { return top.Imports[i] < top.Imports[j] }) { + t.Errorf("top.Imports not sorted: got %v, want it sorted", top.Imports) + } + if !sort.SliceIsSorted(top.Functions, func(i, j int) bool { return top.Functions[i] < top.Functions[j] }) { + t.Errorf("top.Types not sorted: got %v, want it sorted", top.Functions) + } + if !sort.SliceIsSorted(top.Types, func(i, j int) bool { return top.Types[i] < top.Types[j] }) { + t.Errorf("top.Types not sorted: got %v, want it sorted", top.Types) + } + if !sort.SliceIsSorted(top.Emitters, func(i, j int) bool { return top.Emitters[i].Name < top.Emitters[j].Name }) { + t.Errorf("top.Emitters not sorted by name: got %v, want it sorted", top.Emitters) + } + if !sort.SliceIsSorted(top.Inputs, func(i, j int) bool { return top.Inputs[i].Name < top.Inputs[j].Name }) { + t.Errorf("top.Inputs not sorted by name: got %v, want it sorted", top.Inputs) + } + if !sort.SliceIsSorted(top.Shims, func(i, j int) 
bool { return top.Shims[i].Name < top.Shims[j].Name }) { + t.Errorf("top.Shims not sorted: got %v, want it sorted", top.Shims) + } +} + +func TestTop_ProcessImports(t *testing.T) { + needsFiltering := []string{"context", "keepit", "fmt", "io", "reflect", "unrelated"} + + tests := []struct { + name string + got *Top + want []string + }{ + {name: "default", got: &Top{}, want: []string{"context", "keepit", "fmt", "io", "unrelated"}}, + {name: "emit", got: &Top{Emitters: []Emitter{{Name: "emit"}}}, want: []string{ExecImport, TypexImport, "keepit", "fmt", "io", "unrelated"}}, + {name: "iter", got: &Top{Inputs: []Input{{Name: "iter"}}}, want: []string{ExecImport, "context", "keepit", "unrelated"}}, + {name: "iterWTime", got: &Top{Inputs: []Input{{Name: "iterWTime", Time: true}}}, want: []string{ExecImport, TypexImport, "context", "keepit", "unrelated"}}, + {name: "shim", got: &Top{Shims: []Func{{Name: "emit"}}}, want: []string{ReflectxImport, "context", "keepit", "fmt", "io", "unrelated"}}, + {name: "iter&emit", got: &Top{Emitters: []Emitter{{Name: "emit"}}, Inputs: []Input{{Name: "iter"}}}, want: []string{ExecImport, TypexImport, "keepit", "unrelated"}}, + {name: "functions", got: &Top{Functions: []string{"func1"}}, want: []string{RuntimeImport, "context", "keepit", "fmt", "io", "unrelated"}}, + {name: "types", got: &Top{Types: []string{"func1"}}, want: []string{RuntimeImport, "context", "keepit", "fmt", "io", "unrelated"}}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + top := test.got + top.Imports = needsFiltering + top = top.processImports() + for i := range top.Imports { + if top.Imports[i] != test.want[i] { + t.Fatalf("want %v, got %v", test.want, top.Imports) + } + } + }) + } +} + +func TestName(t *testing.T) { + tests := []struct { + have, want string + }{ + {"int", "Int"}, + {"foo.MyInt", "Foo۰MyInt"}, + {"[]beam.X", "Typex۰XSlice"}, + {"map[int]beam.X", "Map_int_typex۰X"}, + {"map[string]*beam.X", "Map_string_Ꮨtypex۰X"}, + } + for 
_, test := range tests { + if got := Name(test.have); got != test.want { + t.Errorf("Name(%v) = %v, want %v", test.have, got, test.want) + } + } +} + +func TestFuncName(t *testing.T) { + tests := []struct { + in, out []string + want string + }{ + {in: []string{"Int"}, out: []string{"Int"}, want: "IntГInt"}, + {in: []string{"Int"}, out: []string{}, want: "IntГ"}, + {in: []string{}, out: []string{"Bool"}, want: "ГBool"}, + {in: []string{"Bool", "String"}, out: []string{"Int", "Bool"}, want: "BoolStringГIntBool"}, + {in: []string{"String", "Map_int_typex۰X"}, out: []string{"Int", "Typex۰XSlice"}, want: "StringMap_int_typex۰XГIntTypex۰XSlice"}, + } + for _, test := range tests { + if got := FuncName(test.in, test.out); got != test.want { + t.Errorf("FuncName(%v,%v) = %v, want %v", test.in, test.out, got, test.want) + } + } +} + +func TestFile(t *testing.T) { + top := Top{ + Package: "gentest", + Imports: []string{"z", "a", "r"}, + Functions: []string{"z", "a", "r"}, + Types: []string{"z", "a", "r"}, + Emitters: []Emitter{ + {Name: "z", Type: "func(int)", Val: "Int"}, + {Name: "a", Type: "func(bool, int) bool", Key: "bool", Val: "int"}, + {Name: "r", Type: "func(typex.EventTime, string) bool", Time: true, Val: "string"}, + }, + Inputs: []Input{ + {Name: "z", Type: "func(*int) bool"}, + {Name: "a", Type: "func(*bool, *int) bool", Key: "bool", Val: "int"}, + {Name: "r", Type: "func(*typex.EventTime, *string) bool", Time: true, Val: "string"}, + }, + Shims: []Func{ + {Name: "z", Type: "func(string, func(int))", In: []string{"string", "func(int)"}}, + {Name: "a", Type: "func(float64) (int, int)", In: []string{"float64"}, Out: []string{"int", "int"}}, + {Name: "r", Type: "func(string, func(int))", In: []string{"string", "func(int)"}}, + }, + } + top.sort() + + var b bytes.Buffer + if err := vampireTemplate.Funcs(funcMap).Execute(&b, top); err != nil { + t.Errorf("error generating template: %v", err) + } +} + +func TestMkargs(t *testing.T) { + tests := []struct { + n int + 
format, typ string + want string + }{ + {n: 0, format: "Foo", typ: "Bar", want: ""}, + {n: 1, format: "arg%d", typ: "Bar", want: "arg0 Bar"}, + {n: 4, format: "a%d", typ: "Baz", want: "a0, a1, a2, a3 Baz"}, + } + for _, test := range tests { + if got := mkargs(test.n, test.format, test.typ); got != test.want { + t.Errorf("mkargs(%v,%v,%v) = %v, want %v", test.n, test.format, test.typ, got, test.want) + } + } +} + +func TestMkparams(t *testing.T) { + tests := []struct { + format string + types []string + want string + }{ + {format: "Foo", types: []string{}, want: ""}, + {format: "arg%d %v", types: []string{"Bar"}, want: "arg0 Bar"}, + {format: "a%d %v", types: []string{"Foo", "Baz", "interface{}"}, want: "a0 Foo, a1 Baz, a2 interface{}"}, + } + for _, test := range tests { + if got := mkparams(test.format, test.types); got != test.want { + t.Errorf("mkparams(%v,%v) = %v, want %v", test.format, test.types, got, test.want) + } + } +} + +func TestMktuple(t *testing.T) { + tests := []struct { + n int + v string + want string + }{ + {n: 0, v: "Foo", want: ""}, + {n: 1, v: "Bar", want: "Bar"}, + {n: 4, v: "Baz", want: "Baz, Baz, Baz, Baz"}, + } + for _, test := range tests { + if got := mktuple(test.n, test.v); got != test.want { + t.Errorf("mktuple(%v,%v) = %v, want %v", test.n, test.v, got, test.want) + } + } +} + +func TestMktuplef(t *testing.T) { + tests := []struct { + n int + format, typ string + want string + }{ + {n: 0, format: "Foo%d", want: ""}, + {n: 1, format: "arg%d", want: "arg0"}, + {n: 4, format: "a%d", want: "a0, a1, a2, a3"}, + } + for _, test := range tests { + if got := mktuplef(test.n, test.format); got != test.want { + t.Errorf("mktuplef(%v,%v) = %v, want %v", test.n, test.format, got, test.want) + } + } +} diff --git a/sdks/go/pkg/beam/util/starcgenx/starcgenx.go b/sdks/go/pkg/beam/util/starcgenx/starcgenx.go new file mode 100644 index 0000000..003a3c9 --- /dev/null +++ b/sdks/go/pkg/beam/util/starcgenx/starcgenx.go @@ -0,0 +1,562 @@ +// Licensed to 
the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package starcgenx is a Static Analysis Type Assertion shim and Registration Code Generator +// which provides an extractor to extract types from a package, in order to generate +// appropriate shims for that package. +// +// It's written for use by the starcgen tool, but separate to permit +// alternative "go/importer" Importers for accessing types from imported packages. +package starcgenx + +import ( + "bytes" + "fmt" + "go/ast" + "go/token" + "go/types" + "strings" + + "github.com/apache/beam/sdks/go/pkg/beam/util/shimx" +) + +// NewExtractor returns an extractor for the given package. +// All of its internal maps start out empty and allExported starts as true. +func NewExtractor(pkg string) *Extractor { + return &Extractor{ + Package: pkg, + functions: make(map[string]struct{}), + types: make(map[string]struct{}), + funcs: make(map[string]*types.Signature), + emits: make(map[string]shimx.Emitter), + iters: make(map[string]shimx.Input), + imports: make(map[string]struct{}), + allExported: true, + } +} + +// Extractor contains and uniquifies the cache of types and things that need to be generated. 
+type Extractor struct { + w bytes.Buffer + Package string + debug bool + + // Ids is an optional slice of package local identifiers + Ids []string + + // Register and uniquify the needed shims for each kind. + // Functions to Register + functions map[string]struct{} + // Types to Register (structs, essentially) + types map[string]struct{} + // FuncShims needed + funcs map[string]*types.Signature + // Emitter Shims needed + emits map[string]shimx.Emitter + // Iterator Shims needed + iters map[string]shimx.Input + + // list of packages we need to import. + imports map[string]struct{} + + allExported bool // Marks if all ptransforms are exported and available in main. +} + +// Summary prints out a summary of the shims and registrations to +// be generated to the buffer. +func (e *Extractor) Summary() { + e.Print("\n") + e.Print("Summary\n") + e.Printf("All exported?: %v\n", e.allExported) + e.Printf("%d\t Functions\n", len(e.functions)) + e.Printf("%d\t Types\n", len(e.types)) + e.Printf("%d\t Shims\n", len(e.funcs)) + e.Printf("%d\t Emits\n", len(e.emits)) + e.Printf("%d\t Inputs\n", len(e.iters)) +} + +// lifecycleMethodName returns if the passed in string is one of the lifecycle method names used +// by the Go SDK as DoFn or CombineFn lifecycle methods. These are the only methods that need +// shims generated for them, as per beam/core/graph/fn.go +// TODO(lostluck): Move this to beam/core/graph/fn.go, so it can stay up to date. +func lifecycleMethodName(n string) bool { + switch n { + case "ProcessElement", "StartBundle", "FinishBundle", "Setup", "Teardown", "CreateAccumulator", "AddInput", "MergeAccumulators", "ExtractOutput", "Compact": + return true + default: + return false + } +} + +// Bytes forwards to fmt.Fprint to the extractor buffer. +func (e *Extractor) Bytes() []byte { + return e.w.Bytes() +} + +// Print forwards to fmt.Fprint to the extractor buffer. 
+func (e *Extractor) Print(s string) { + if e.debug { + fmt.Fprint(&e.w, s) + } +} + +// Printf forwards to fmt.Printf to the extractor buffer. +func (e *Extractor) Printf(f string, args ...interface{}) { + if e.debug { + fmt.Fprintf(&e.w, f, args...) + } +} + +// FromAsts analyses the contents of a package +func (e *Extractor) FromAsts(imp types.Importer, fset *token.FileSet, files []*ast.File) error { + conf := types.Config{ + Importer: imp, + IgnoreFuncBodies: true, + DisableUnusedImportCheck: true, + } + info := &types.Info{ + Defs: make(map[*ast.Ident]types.Object), + } + if len(e.Ids) != 0 { + // TODO(lostluck): This becomes unnnecessary iff we can figure out + // which ParDos are being passed to beam.ParDo or beam.Combine. + // If there are ids, we need to also look at function bodies, and uses. + var checkFuncBodies bool + for _, v := range e.Ids { + if strings.Contains(v, ".") { + checkFuncBodies = true + break + } + } + conf.IgnoreFuncBodies = !checkFuncBodies + info.Uses = make(map[*ast.Ident]types.Object) + } + + if _, err := conf.Check(e.Package, fset, files, info); err != nil { + return fmt.Errorf("failed to type check package %s : %v", e.Package, err) + } + + e.Print("/*\n") + var idsRequired, idsFound map[string]bool + if len(e.Ids) > 0 { + e.Printf("Filtering by %d identifiers: %q\n", len(e.Ids), strings.Join(e.Ids, ", ")) + idsRequired = make(map[string]bool) + idsFound = make(map[string]bool) + for _, id := range e.Ids { + idsRequired[id] = true + } + } + e.Print("CHECKING DEFS\n") + for id, obj := range info.Defs { + e.fromObj(fset, id, obj, idsRequired, idsFound) + } + e.Print("CHECKING USES\n") + for id, obj := range info.Uses { + e.fromObj(fset, id, obj, idsRequired, idsFound) + } + var notFound []string + for _, k := range e.Ids { + if !idsFound[k] { + notFound = append(notFound, k) + } + } + if len(notFound) > 0 { + return fmt.Errorf("couldn't find the following identifiers; please check for typos, or remove them: %v", 
strings.Join(notFound, ", ")) + } + e.Print("*/\n") + + return nil +} + +func (e *Extractor) isRequired(ident string, obj types.Object, idsRequired, idsFound map[string]bool) bool { + if idsRequired == nil { + return true + } + if idsFound == nil { + panic("broken invariant: idsFound map is nil, but idsRequired map exists") + } + // If we're filtering IDs, then it needs to be in the filtered identifiers, + // or it's receiver type identifier needs to be in the filtered identifiers. + if idsRequired[ident] { + idsFound[ident] = true + return true + } + // Check if this is a function. + sig, ok := obj.Type().(*types.Signature) + if !ok { + return false + } + // If this is a function, and it has a receiver, it's a method. + if recv := sig.Recv(); recv != nil && lifecycleMethodName(ident) { + // We don't want to care about pointers, so dereference to value type. + t := recv.Type() + p, ok := t.(*types.Pointer) + for ok { + t = p.Elem() + p, ok = t.(*types.Pointer) + } + ts := types.TypeString(t, e.qualifier) + e.Printf("RRR has %v, ts: %s %s--- ", sig, ts, ident) + if !idsRequired[ts] { + e.Print("IGNORE\n") + return false + } + e.Print("KEEP\n") + idsFound[ts] = true + return true + } + return false +} + +func (e *Extractor) fromObj(fset *token.FileSet, id *ast.Ident, obj types.Object, idsRequired, idsFound map[string]bool) { + if obj == nil { // Omit the package declaration. + e.Printf("%s: %q has no object, probably a package\n", + fset.Position(id.Pos()), id.Name) + return + } + + pkg := obj.Pkg() + if pkg == nil { + e.Printf("%s: %q has no package \n", + fset.Position(id.Pos()), id.Name) + // No meaningful identifier. + return + } + ident := fmt.Sprintf("%s.%s", pkg.Name(), obj.Name()) + if pkg.Name() == e.Package { + ident = obj.Name() + } + if !e.isRequired(ident, obj, idsRequired, idsFound) { + return + } + + switch ot := obj.(type) { + case *types.Var: + // Vars are tricky since they could be anything, and anywhere (package scope, parameters, etc) + // eg. 
Flags, or Field Tags, among others. + // I'm increasingly convinced that we should simply igonore vars. + // Do nothing for vars. + case *types.Func: + sig := obj.Type().(*types.Signature) + if recv := sig.Recv(); recv != nil { + // Methods don't need registering, but they do need shim generation. + e.Printf("%s: %q is a method of %v -> %v--- %T %v %v %v\n", + fset.Position(id.Pos()), id.Name, recv.Type(), obj, obj, id, obj.Pkg(), obj.Type()) + if !lifecycleMethodName(id.Name) { + // If this is not a lifecycle method, we should ignore it. + return + } + } else if id.Name != "init" { + // init functions are special and should be ignored. + // Functions need registering, as well as shim generation. + e.Printf("%s: %q is a top level func %v --- %T %v %v %v\n", + fset.Position(id.Pos()), ident, obj, obj, id, obj.Pkg(), obj.Type()) + e.functions[ident] = struct{}{} + } + // For functions from other packages. + if pkg.Name() != e.Package { + e.imports[pkg.Path()] = struct{}{} + } + + e.funcs[e.sigKey(sig)] = sig + e.extractParameters(sig) + e.Printf("\t%v\n", sig) + case *types.TypeName: + e.Printf("%s: %q is a type %v --- %T %v %v %v %v\n", + fset.Position(id.Pos()), id.Name, obj, obj, id, obj.Pkg(), obj.Type(), obj.Name()) + // Probably need to sanity check that this type actually is/has a ProcessElement + // or MergeAccumulators defined for this type so unnecessary registrations don't happen, + // an can explicitly produce an error if an explicitly named type *isn't* a DoFn or CombineFn. + e.extractType(ot) + default: + e.Printf("%s: %q defines %v --- %T %v %v %v\n", + fset.Position(id.Pos()), types.ObjectString(obj, e.qualifier), obj, obj, id, obj.Pkg(), obj.Type()) + } +} + +func (e *Extractor) extractType(ot *types.TypeName) { + name := types.TypeString(ot.Type(), e.qualifier) + // Unwrap an alias by one level. + // Attempting to deference a full chain of aliases runs the risk of crossing + // a visibility boundary such as internal packages. 
+ // A single level is safe since the code we're analysing imports it, + // so we can assume the generated code can access it too. + if ot.IsAlias() { + if t, ok := ot.Type().(*types.Named); ok { + ot = t.Obj() + name = types.TypeString(t, e.qualifier) + + if pkg := t.Obj().Pkg(); pkg != nil { + e.imports[pkg.Path()] = struct{}{} + } + } + } + e.types[name] = struct{}{} +} + +// Examines the signature and extracts types of parameters for generating +// necessary imports and emitter and iterator code. +func (e *Extractor) extractParameters(sig *types.Signature) { + in := sig.Params() // *types.Tuple + for i := 0; i < in.Len(); i++ { + s := in.At(i) // *types.Var + + // Pointer types need to be iteratively unwrapped until we're at the base type, + // so we can get the import if necessary. + t := s.Type() + p, ok := t.(*types.Pointer) + for ok { + t = p.Elem() + p, ok = t.(*types.Pointer) + } + // Here's were we ensure we register new imports. + if t, ok := t.(*types.Named); ok { + if pkg := t.Obj().Pkg(); pkg != nil { + e.imports[pkg.Path()] = struct{}{} + } + e.extractType(t.Obj()) + } + + if a, ok := s.Type().(*types.Signature); ok { + // Check if the type is an emitter or iterator for the specialized + // shim generation for those types. + if emt, ok := e.makeEmitter(a); ok { + e.emits[emt.Name] = emt + } + if ipt, ok := e.makeInput(a); ok { + e.iters[ipt.Name] = ipt + } + // Tail recurse on functional parameters. 
+ e.extractParameters(a) + } + } +} + +func (e *Extractor) qualifier(pkg *types.Package) string { + n := tail(pkg.Name()) + if n == e.Package { + return "" + } + return n +} + +func tail(path string) string { + if i := strings.LastIndex("/", path); i >= 0 { + path = path[i:] + } + return path +} + +func (e *Extractor) tupleStrings(t *types.Tuple) []string { + var vs []string + for i := 0; i < t.Len(); i++ { + v := t.At(i) + vs = append(vs, types.TypeString(v.Type(), e.qualifier)) + } + return vs +} + +// sigKey produces a variable name agnostic key for the function signature. +func (e *Extractor) sigKey(sig *types.Signature) string { + ps, rs := e.tupleStrings(sig.Params()), e.tupleStrings(sig.Results()) + return fmt.Sprintf("func(%v) (%v)", strings.Join(ps, ","), strings.Join(rs, ",")) +} + +// Generate produces an additional file for the Go package that was extracted, +// to be included in a subsequent compilation. +func (e *Extractor) Generate(filename string) []byte { + var functions []string + for fn := range e.functions { + // No extra processing necessary, since these should all be package local. 
+ functions = append(functions, fn) + } + var typs []string + for t := range e.types { + typs = append(typs, t) + } + var shims []shimx.Func + for sig, t := range e.funcs { + shim := shimx.Func{Type: sig} + var inNames []string + in := t.Params() // *types.Tuple + for i := 0; i < in.Len(); i++ { + s := in.At(i) // *types.Var + shim.In = append(shim.In, types.TypeString(s.Type(), e.qualifier)) + inNames = append(inNames, e.NameType(s.Type())) + } + var outNames []string + out := t.Results() // *types.Tuple + for i := 0; i < out.Len(); i++ { + s := out.At(i) + shim.Out = append(shim.Out, types.TypeString(s.Type(), e.qualifier)) + outNames = append(outNames, e.NameType(s.Type())) + } + shim.Name = shimx.FuncName(inNames, outNames) + shims = append(shims, shim) + } + var emits []shimx.Emitter + for _, t := range e.emits { + emits = append(emits, t) + } + var inputs []shimx.Input + for _, t := range e.iters { + inputs = append(inputs, t) + } + + var imports []string + for k := range e.imports { + if k == "" || k == e.Package { + continue + } + imports = append(imports, k) + } + + top := shimx.Top{ + FileName: filename, + ToolName: "starcgen", + Package: e.Package, + Imports: imports, + Functions: functions, + Types: typs, + Shims: shims, + Emitters: emits, + Inputs: inputs, + } + e.Print("\n") + shimx.File(&e.w, &top) + return e.w.Bytes() +} + +func (e *Extractor) makeEmitter(sig *types.Signature) (shimx.Emitter, bool) { + // Emitters must have no return values. + if sig.Results().Len() != 0 { + return shimx.Emitter{}, false + } + p := sig.Params() + emt := shimx.Emitter{Type: e.sigKey(sig)} + switch p.Len() { + case 1: + emt.Time = false + emt.Val = e.varString(p.At(0)) + case 2: + // TODO(rebo): Fix this when imports are resolved. + // This is the tricky one... Need to verify what happens with aliases. + // And get a candle to compare this against somehwere. isEventTime(p.At(0)) maybe. 
+ // if p.At(0) == typex.EventTimeType { + // emt.Time = true + // } else { + emt.Key = e.varString(p.At(0)) + //} + emt.Val = e.varString(p.At(1)) + case 3: + // If there's 3, the first one must be typex.EventTime. + emt.Time = true + emt.Key = e.varString(p.At(1)) + emt.Val = e.varString(p.At(2)) + default: + return shimx.Emitter{}, false + } + if emt.Time { + emt.Name = fmt.Sprintf("ET%s%s", shimx.Name(emt.Key), shimx.Name(emt.Val)) + } else { + emt.Name = fmt.Sprintf("%s%s", shimx.Name(emt.Key), shimx.Name(emt.Val)) + } + return emt, true +} + +// makeInput checks if the given signature is an iterator or not, and if so, +// returns a shimx.Input struct for the signature for use by the code +// generator. The canonical check for an iterater signature is in the +// funcx.UnfoldIter function which uses the reflect library, +// and this logic is replicated here. +func (e *Extractor) makeInput(sig *types.Signature) (shimx.Input, bool) { + r := sig.Results() + if r.Len() != 1 { + return shimx.Input{}, false + } + // Iterators must return a bool. + if b, ok := r.At(0).Type().(*types.Basic); !ok || b.Kind() != types.Bool { + return shimx.Input{}, false + } + p := sig.Params() + for i := 0; i < p.Len(); i++ { + // All params for iterators must be pointers. + if _, ok := p.At(i).Type().(*types.Pointer); !ok { + return shimx.Input{}, false + } + } + itr := shimx.Input{Type: e.sigKey(sig)} + switch p.Len() { + case 1: + itr.Time = false + itr.Val = e.deref(p.At(0)) + case 2: + // TODO(rebo): Fix this when imports are resolved. + // This is the tricky one... Need to verify what happens with aliases. + // And get a candle to compare this against somehwere. isEventTime(p.At(0)) maybe. + // if p.At(0) == typex.EventTimeType { + // itr.Time = true + // } else { + itr.Key = e.deref(p.At(0)) + //} + itr.Val = e.deref(p.At(1)) + case 3: + // If there's 3, the first one must be typex.EventTime. 
+ itr.Time = true + itr.Key = e.deref(p.At(1)) + itr.Val = e.deref(p.At(2)) + default: + return shimx.Input{}, false + } + if itr.Time { + itr.Name = fmt.Sprintf("ET%s%s", shimx.Name(itr.Key), shimx.Name(itr.Val)) + } else { + itr.Name = fmt.Sprintf("%s%s", shimx.Name(itr.Key), shimx.Name(itr.Val)) + } + return itr, true +} + +// deref returns the string identifier for the element type of a pointer var. +// deref panics if the var type is not a pointer. +func (e *Extractor) deref(v *types.Var) string { + p := v.Type().(*types.Pointer) + return types.TypeString(p.Elem(), e.qualifier) +} + +// varString provides the correct type for a variable within the +// package for which we're generated code. +func (e *Extractor) varString(v *types.Var) string { + return types.TypeString(v.Type(), e.qualifier) +} + +// NameType turns a reflect.Type into a strying based on it's name. +// It prefixes Emit or Iter if the function satisfies the constrains of those types. +func (e *Extractor) NameType(t types.Type) string { + switch a := t.(type) { + case *types.Signature: + if emt, ok := e.makeEmitter(a); ok { + return "Emit" + emt.Name + } + if ipt, ok := e.makeInput(a); ok { + return "Iter" + ipt.Name + } + return shimx.Name(e.sigKey(a)) + case *types.Pointer: + return e.NameType(a.Elem()) + case *types.Slice: + return "Sliceof" + e.NameType(a.Elem()) + default: + return shimx.Name(types.TypeString(t, e.qualifier)) + } +} diff --git a/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go b/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go new file mode 100644 index 0000000..9141acb --- /dev/null +++ b/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package starcgenx + +import ( + "go/ast" + "go/importer" + "go/parser" + "go/token" + "strings" + "testing" +) + +func TestExtractor(t *testing.T) { + tests := []struct { + name string + pkg string + files []string + ids []string + expected []string + excluded []string + }{ + {name: "pardo1", files: []string{pardo}, pkg: "pardo", + expected: []string{"runtime.RegisterFunction(MyIdent)", "runtime.RegisterFunction(MyDropVal)", "runtime.RegisterFunction(MyOtherDoFn)", "runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerStringГString", "funcMakerIntStringГInt", "funcMakerFooГStringFoo"}, + }, + {name: "emits1", files: []string{emits}, pkg: "emits", + expected: []string{"runtime.RegisterFunction(anotherFn)", "runtime.RegisterFunction(emitFn)", "runtime.RegisterType(reflect.TypeOf((*reInt)(nil)).Elem())", "funcMakerEmitIntIntГ", "emitMakerIntInt", "funcMakerIntIntEmitIntIntГError"}, + }, + {name: "iters1", files: []string{iters}, pkg: "iters", + expected: []string{"runtime.RegisterFunction(iterFn)", "funcMakerStringIterIntГ", "iterMakerInt"}, + }, + {name: "structs1", files: []string{structs}, pkg: "structs", ids: []string{"myDoFn"}, + expected: []string{"runtime.RegisterType(reflect.TypeOf((*myDoFn)(nil)).Elem())", "funcMakerEmitIntГ", "emitMakerInt", "funcMakerValTypeValTypeEmitIntГ", "runtime.RegisterType(reflect.TypeOf((*valType)(nil)).Elem())"}, + excluded: 
[]string{"funcMakerStringГ", "emitMakerString", "nonPipelineType"}, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + fset := token.NewFileSet() + var fs []*ast.File + for i, f := range test.files { + n, err := parser.ParseFile(fset, "", f, 0) + if err != nil { + t.Fatalf("couldn't parse test.files[%d]: %v", i, err) + } + fs = append(fs, n) + } + e := NewExtractor(test.pkg) + e.Ids = test.ids + if err := e.FromAsts(importer.Default(), fset, fs); err != nil { + t.Fatal(err) + } + data := e.Generate("test_shims.go") + s := string(data) + for _, i := range test.expected { + if !strings.Contains(s, i) { + t.Errorf("expected %q in generated file", i) + } + } + for _, i := range test.excluded { + if strings.Contains(s, i) { + t.Errorf("found %q in generated file", i) + } + } + t.Log(s) + }) + } +} + +const pardo = ` +package pardo + +func MyIdent(v string) string { + return v +} + +func MyDropVal(k int,v string) int { + return k +} + +// A user defined type +type foo struct{} + +func MyOtherDoFn(v foo) (string,foo) { + return "constant" +} +` + +const emits = ` +package emits + +type reInt int + +func anotherFn(emit func(int,int)) { + emit(0, 0) +} + +func emitFn(k,v int, emit func(int,int)) error { + for i := 0; i < v; i++ { emit(k, i) } + return nil +} +` +const iters = ` +package iters + +func iterFn(k string, iters func(*int) bool) {} +` + +const structs = ` +package structs + +type myDoFn struct{} + +// valType should be picked up via processElement +type valType int + +func (f *myDoFn) ProcessElement(k, v valType, emit func(int)) {} + +func (f *myDoFn) Setup(emit func(int)) {} +func (f *myDoFn) StartBundle(emit func(int)) {} +func (f *myDoFn) FinishBundle(emit func(int)) error {} +func (f *myDoFn) Teardown(emit func(int)) {} + +type nonPipelineType int + +// UnrelatedMethods shouldn't have shims or tangents generated for them +func (f *myDoFn) UnrelatedMethod1(v string) {} +func (f *myDoFn) UnrelatedMethod2(notEmit 
func(string)) {} + +func (f *myDoFn) UnrelatedMethod3(notEmit func(nonPipelineType)) {} +`