This is an automated email from the ASF dual-hosted git repository.
liuhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-go.git
The following commit(s) were added to refs/heads/main by this push:
new 4bee1d4 Add more unit tests (#80)
4bee1d4 is described below
commit 4bee1d47d62b44ddc03b3b3c72dc546b45a5cc0b
Author: mrproliu <[email protected]>
AuthorDate: Thu Jul 20 14:27:36 2023 +0800
Add more unit tests (#80)
---
tools/go-agent/instrument/logger/frameworks/zap.go | 14 +-
tools/go-agent/instrument/logger/instrument.go | 12 +-
.../go-agent/instrument/plugins/enhance_method.go | 6 +-
tools/go-agent/instrument/runtime/instrument.go | 4 +-
tools/go-agent/tools/directive_test.go | 112 ++++++++++
tools/go-agent/tools/dst_test.go | 247 +++++++++++++++++++++
tools/go-agent/tools/enhancement.go | 35 ++-
tools/go-agent/tools/enhancement_test.go | 102 +++++++++
8 files changed, 509 insertions(+), 23 deletions(-)
diff --git a/tools/go-agent/instrument/logger/frameworks/zap.go b/tools/go-agent/instrument/logger/frameworks/zap.go
index 6af67d3..8554943 100644
--- a/tools/go-agent/instrument/logger/frameworks/zap.go
+++ b/tools/go-agent/instrument/logger/frameworks/zap.go
@@ -150,8 +150,8 @@ func (z *Zap) CustomizedEnhance(path string, curFile
*dst.File, cursor *dstutil.
n.Name.Name == "check" &&
n.Type.Results != nil && len(n.Type.Results.List) > 0 &&
tools.GenerateTypeNameByExp(n.Type.Results.List[0].Type) ==
"*zapcore.CheckedEntry" {
- entryName :=
tools.EnhanceParameterNames(n.Type.Results, true)[0].Name
- recvName := tools.EnhanceParameterNames(n.Recv,
true)[0].Name
+ entryName :=
tools.EnhanceParameterNames(n.Type.Results, tools.FieldListTypeResult)[0].Name
+ recvName := tools.EnhanceParameterNames(n.Recv,
tools.FieldListTypeRecv)[0].Name
// init the zapcore variables
z.initFunction =
tools.GoStringToDecls(fmt.Sprintf(`func initZapCore() {
@@ -171,9 +171,9 @@ zapcore.SWLogEnable = %s
if n.Recv != nil && len(n.Recv.List) == 1 && n.Name.Name ==
"With" &&
n.Type.Params != nil && len(n.Type.Params.List) == 1 &&
tools.GenerateTypeNameByExp(n.Recv.List[0].Type) ==
"*Logger" && tools.GenerateTypeNameByExp(n.Type.Params.List[0].Type) ==
"[]Field" {
- recvs := tools.EnhanceParameterNames(n.Recv, false)
- parameters :=
tools.EnhanceParameterNames(n.Type.Params, false)
- results := tools.EnhanceParameterNames(n.Type.Results,
true)
+ recvs := tools.EnhanceParameterNames(n.Recv,
tools.FieldListTypeRecv)
+ parameters :=
tools.EnhanceParameterNames(n.Type.Params, tools.FieldListTypeParam)
+ results := tools.EnhanceParameterNames(n.Type.Results,
tools.FieldListTypeResult)
return z.enhanceMethod(n, fmt.Sprintf(`defer func() {if
%s != nil { %s.SWFields = %sZap%s(%s, %s.SWFields) }}()`,
results[0].Name, results[0].Name,
rewrite.StaticMethodPrefix, "KnownFieldFilter", parameters[0].Name,
recvs[0].Name)), true
@@ -184,8 +184,8 @@ zapcore.SWLogEnable = %s
n.Name.Name == "Write" &&
n.Type.Params != nil && len(n.Type.Params.List) == 1 &&
tools.GenerateTypeNameByExp(n.Type.Params.List[0].Type)
== "[]Field" {
- recvs := tools.EnhanceParameterNames(n.Recv, false)
- parameters :=
tools.EnhanceParameterNames(n.Type.Params, false)
+ recvs := tools.EnhanceParameterNames(n.Recv,
tools.FieldListTypeRecv)
+ parameters :=
tools.EnhanceParameterNames(n.Type.Params, tools.FieldListTypeParam)
return z.enhanceMethod(n, fmt.Sprintf(`if %s != nil {
%s = %sZapcore%s(%s, %s, %s.SWFields, %s.SWContext, %s.SWContextField,
SWReporterEnable, SWLogEnable, SWReporterLabelKeys) }`,
recvs[0].Name, parameters[0].Name,
rewrite.StaticMethodPrefix, "ReportLogFromZapEntry", recvs[0].Name,
parameters[0].Name, recvs[0].Name,
recvs[0].Name, recvs[0].Name)), true
diff --git a/tools/go-agent/instrument/logger/instrument.go b/tools/go-agent/instrument/logger/instrument.go
index 38e9caa..bd05234 100644
--- a/tools/go-agent/instrument/logger/instrument.go
+++ b/tools/go-agent/instrument/logger/instrument.go
@@ -117,9 +117,9 @@ func (i *Instrument) addAutomaticBindFunc(path string,
curFile dst.Node, fun *ds
Results []*tools.ParameterInfo
}{
AutomaticBindFuncName: generateFuncName,
- Recvs: tools.EnhanceParameterNames(fun.Recv,
false),
- Parameters:
tools.EnhanceParameterNames(fun.Type.Params, false),
- Results:
tools.EnhanceParameterNames(fun.Type.Results, true),
+ Recvs: tools.EnhanceParameterNames(fun.Recv,
tools.FieldListTypeRecv),
+ Parameters:
tools.EnhanceParameterNames(fun.Type.Params, tools.FieldListTypeParam),
+ Results:
tools.EnhanceParameterNames(fun.Type.Results, tools.FieldListTypeResult),
})
importAnalyzer.AnalyzeNeedsImports(path, fun.Recv)
importAnalyzer.AnalyzeNeedsImports(path, fun.Type.Params)
@@ -307,14 +307,14 @@ func (i *Instrument) addAutomaticFuncDelegators(f
*dst.File, importDecl *dst.Gen
},
}
- for i, recv := range
tools.EnhanceParameterNamesWithPackagePrefix(packageName, fun.Func.Recv, false)
{
+ for i, recv := range
tools.EnhanceParameterNamesWithPackagePrefix(packageName, fun.Func.Recv,
tools.FieldListTypeRecv) {
delegatorFunc.Type.Params.List =
append(delegatorFunc.Type.Params.List, &dst.Field{
Names:
[]*dst.Ident{dst.NewIdent(fmt.Sprintf("recv_%d", i))},
Type: &dst.StarExpr{X: recv.PackagedType()},
})
}
- for i, parameter := range
tools.EnhanceParameterNamesWithPackagePrefix(packageName, fun.Func.Type.Params,
false) {
+ for i, parameter := range
tools.EnhanceParameterNamesWithPackagePrefix(packageName, fun.Func.Type.Params,
tools.FieldListTypeParam) {
packagedType := parameter.PackagedType()
// if the parameter is dynamic list, then change it to
the array type
if el, ok := packagedType.(*dst.Ellipsis); ok {
@@ -326,7 +326,7 @@ func (i *Instrument) addAutomaticFuncDelegators(f
*dst.File, importDecl *dst.Gen
})
}
- for i, result := range
tools.EnhanceParameterNamesWithPackagePrefix(packageName,
fun.Func.Type.Results, true) {
+ for i, result := range
tools.EnhanceParameterNamesWithPackagePrefix(packageName,
fun.Func.Type.Results, tools.FieldListTypeResult) {
delegatorFunc.Type.Params.List =
append(delegatorFunc.Type.Params.List, &dst.Field{
Names:
[]*dst.Ident{dst.NewIdent(fmt.Sprintf("ret_%d", i))},
Type: &dst.StarExpr{X: result.PackagedType()},
diff --git a/tools/go-agent/instrument/plugins/enhance_method.go b/tools/go-agent/instrument/plugins/enhance_method.go
index 4fdf38a..8ebd85a 100644
--- a/tools/go-agent/instrument/plugins/enhance_method.go
+++ b/tools/go-agent/instrument/plugins/enhance_method.go
@@ -77,11 +77,11 @@ func NewMethodEnhance(inst instrument.Instrument, matcher
*instrument.Point, f *
packageName: pkgName,
InstrumentName: inst.Name(),
InterceptorDefineName: matcher.Interceptor,
- Parameters:
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Type.Params, false),
- Results:
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Type.Results, true),
+ Parameters:
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Type.Params,
tools.FieldListTypeParam),
+ Results:
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Type.Results,
tools.FieldListTypeResult),
}
if f.Recv != nil {
- enhance.Recvs =
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Recv, false)
+ enhance.Recvs =
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Recv,
tools.FieldListTypeRecv)
}
importAnalyzer.AnalyzeNeedsImports(path, f.Type.Params)
diff --git a/tools/go-agent/instrument/runtime/instrument.go b/tools/go-agent/instrument/runtime/instrument.go
index 08d2073..967ae7c 100644
--- a/tools/go-agent/instrument/runtime/instrument.go
+++ b/tools/go-agent/instrument/runtime/instrument.go
@@ -69,8 +69,8 @@ func (r *Instrument) FilterAndEdit(path string, curFile
*dst.File, cursor *dstut
if len(n.Type.Params.List) != 3 {
return false
}
- parameters := tools.EnhanceParameterNames(n.Type.Params, false)
- results := tools.EnhanceParameterNames(n.Type.Results, true)
+ parameters := tools.EnhanceParameterNames(n.Type.Params,
tools.FieldListTypeParam)
+ results := tools.EnhanceParameterNames(n.Type.Results,
tools.FieldListTypeResult)
tools.InsertStmtsBeforeBody(n.Body, `defer func() {
{{(index .Results 0).Name}}.{{.TLSField}} = goroutineChange({{(index
.Parameters 1).Name}}.{{.TLSField}})
diff --git a/tools/go-agent/tools/directive_test.go b/tools/go-agent/tools/directive_test.go
new file mode 100644
index 0000000..8d25627
--- /dev/null
+++ b/tools/go-agent/tools/directive_test.go
@@ -0,0 +1,112 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package tools
+
+import (
+ "testing"
+
+ "github.com/apache/skywalking-go/tools/go-agent/instrument/consts"
+)
+
+// TestContainsDirective checks which comment placements ContainsDirective
+// recognizes for the first declaration parsed from goCode: the directive must
+// be written as "//skywalking:..." (no space after "//") and must appear in
+// the comments in front of the declaration, not after it.
+func TestContainsDirective(t *testing.T) {
+	var tests = []struct {
+		goCode    string
+		directive string
+		contains  bool
+	}{
+		{
+			goCode: `//skywalking:nocopy
+func test() {}`,
+			directive: consts.DirecitveNoCopy,
+			contains:  true,
+		},
+		{
+			goCode: `// test method
+//skywalking:nocopy
+// test method
+func test1() {}`,
+			directive: consts.DirecitveNoCopy,
+			contains:  true,
+		},
+		{
+			goCode: `func test1() {}
+//skywalking:nocopy
+`,
+			directive: consts.DirecitveNoCopy,
+			contains:  false,
+		},
+		{
+			goCode: `// skywalking:nocopy
+func test2() {}`,
+			directive: consts.DirecitveNoCopy,
+			contains:  false,
+		},
+	}
+
+	for i, test := range tests {
+		decls := GoStringToDecls(test.goCode)
+		contains := ContainsDirective(decls[0], test.directive)
+		if contains != test.contains {
+			// %q keeps the multi-line source on one line of test output.
+			t.Errorf("case %d: ContainsDirective(%q, %q) = %v, expected %v", i, test.goCode, test.directive, contains, test.contains)
+		}
+	}
+}
+
+// TestFindDirective checks that FindDirective returns the whole matching
+// directive comment line, including any trailing text, from the comments in
+// front of the first declaration parsed from goCode. It expects "" when no
+// such directive precedes the declaration.
+func TestFindDirective(t *testing.T) {
+	var tests = []struct {
+		goCode    string
+		directive string
+		found     string
+	}{
+		{
+			goCode: `//skywalking:nocopy
+func test() {}`,
+			directive: consts.DirecitveNoCopy,
+			found:     "//skywalking:nocopy",
+		},
+		{
+			goCode: `// test method
+//skywalking:nocopy
+// test method
+func test1() {}`,
+			directive: consts.DirecitveNoCopy,
+			found:     "//skywalking:nocopy",
+		},
+		{
+			goCode: `func test1() {}
+//skywalking:nocopy
+`,
+			directive: consts.DirecitveNoCopy,
+			found:     "",
+		},
+		{
+			goCode: `//skywalking:native test method
+func test2() {}`,
+			directive: consts.DirectiveNative,
+			found:     "//skywalking:native test method",
+		},
+	}
+
+	for i, test := range tests {
+		decls := GoStringToDecls(test.goCode)
+		found := FindDirective(decls[0], test.directive)
+		if found != test.found {
+			// %q makes an empty result ("") visible in the failure message.
+			t.Errorf("case %d: FindDirective(%q, %q) = %q, expected %q", i, test.goCode, test.directive, found, test.found)
+		}
+	}
+}
diff --git a/tools/go-agent/tools/dst_test.go b/tools/go-agent/tools/dst_test.go
new file mode 100644
index 0000000..84b0c65
--- /dev/null
+++ b/tools/go-agent/tools/dst_test.go
@@ -0,0 +1,247 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package tools
+
+import (
+ "fmt"
+ "go/parser"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/dave/dst"
+ "github.com/dave/dst/decorator"
+)
+
+// TestChangePackageImportPath verifies that ChangePackageImportPath rewrites
+// every import path found as a key in pkgUpdates to its mapped value and
+// leaves the remaining imports (and their order) untouched.
+func TestChangePackageImportPath(t *testing.T) {
+	tests := []struct {
+		file       string
+		pkgUpdates map[string]string
+		// expected lists the unquoted import paths after the rewrite, in order.
+		expected []string
+	}{
+		{
+			file: `package main
+import (
+	"github.com/apache/skywalking-go/plugins/core/operator"
+)`,
+			pkgUpdates: map[string]string{
+				"github.com/apache/skywalking-go/plugins/core/operator": "github.com/apache/skywalking-go/agent/core/operator",
+			},
+			expected: []string{
+				"github.com/apache/skywalking-go/agent/core/operator",
+			},
+		},
+		{
+			file: `package main
+import (
+	"fmt"
+	"github.com/apache/skywalking-go/agent/core/operator"
+)`,
+			pkgUpdates: map[string]string{
+				"github.com/apache/skywalking-go/plugins/core/operator": "github.com/apache/skywalking-go/agent/core/operator",
+			},
+			expected: []string{
+				"fmt",
+				"github.com/apache/skywalking-go/agent/core/operator",
+			},
+		},
+	}
+
+	for i, test := range tests {
+		f, err := decorator.ParseFile(nil, "test.go", test.file, parser.ParseComments)
+		if err != nil {
+			t.Fatalf("case %d: %v", i, err)
+		}
+		ChangePackageImportPath(f, test.pkgUpdates)
+		actual := make([]string, 0, len(f.Imports))
+		for _, imp := range f.Imports {
+			// Path.Value is the quoted literal; strip the surrounding quotes.
+			actual = append(actual, strings.TrimSuffix(strings.TrimPrefix(imp.Path.Value, "\""), "\""))
+		}
+		if !reflect.DeepEqual(actual, test.expected) {
+			t.Fatalf("case %d: expect %v, actual %v", i, test.expected, actual)
+		}
+	}
+}
+
+// TestDeletePackageImports verifies that DeletePackageImports removes the
+// import of "testpackage/test" from the file's declarations and strips the
+// "test." qualifier from expressions that referenced it. For example,
+// test.Count becomes Count and []test.Int becomes []Int.
+func TestDeletePackageImports(t *testing.T) {
+	tests := []struct {
+		goCode   string
+		validate func(result dst.Node) bool
+		// isValue wraps goCode as "var val = <goCode>" and validates the value expression.
+		isValue bool
+	}{
+		{
+			goCode:  "test.Count(1)",
+			isValue: true,
+			validate: func(result dst.Node) bool {
+				call, ok := result.(*dst.CallExpr)
+				return ok && reflect.DeepEqual(call.Fun, dst.NewIdent("Count"))
+			},
+		},
+		{
+			goCode:  "[]test.Int{}",
+			isValue: true,
+			validate: func(result dst.Node) bool {
+				lit, ok := result.(*dst.CompositeLit)
+				return ok && reflect.DeepEqual(lit.Type, &dst.ArrayType{
+					Elt: dst.NewIdent("Int"),
+				})
+			},
+		},
+		{
+			goCode: `type Object struct {
+	value test.Int
+}`,
+			isValue: false,
+			validate: func(result dst.Node) bool {
+				structType := result.(*dst.GenDecl).Specs[0].(*dst.TypeSpec).Type.(*dst.StructType)
+				return reflect.DeepEqual(structType.Fields.List[0].Type, dst.NewIdent("Int"))
+			},
+		},
+	}
+
+	for i, test := range tests {
+		content := "import test \"testpackage/test\"\n"
+		if test.isValue {
+			content += "var val = " + test.goCode
+		} else {
+			content += test.goCode
+		}
+		decls := GoStringToDecls(content)
+
+		file := &dst.File{Name: dst.NewIdent("dst"), Decls: decls}
+		DeletePackageImports(file, "testpackage/test")
+		if len(file.Decls) != 1 {
+			// Fatal, not Error: the code below indexes file.Decls[0] and would
+			// panic (or validate the wrong decl) if the import was not removed.
+			t.Fatalf("case %d: failure to delete package, current decl count: %d", i, len(file.Decls))
+		}
+		var actualResult dst.Node
+		if test.isValue {
+			actualResult = file.Decls[0].(*dst.GenDecl).Specs[0].(*dst.ValueSpec).Values[0]
+		} else {
+			actualResult = file.Decls[0]
+		}
+		if !test.validate(actualResult) {
+			t.Fatalf("validate %d error, real get result: %v", i, actualResult)
+		}
+	}
+}
+
+// TestGenerateDSTFileContent parses each fileContent and renders it back with
+// GenerateDSTFileContent. With the given debug info it expects the output to
+// equal resultContent, where a "//line <FilePath>:<Line>" directive sits
+// directly above the function declaration.
+func TestGenerateDSTFileContent(t *testing.T) {
+	cases := []struct {
+		fileContent   string
+		debugInfo     *DebugInfo
+		resultContent string
+	}{
+		{
+			fileContent: `package main
+
+import (
+	"fmt"
+)
+
+func main() {
+}
+`,
+			debugInfo: &DebugInfo{
+				FilePath:     "/test/main.go",
+				Line:         10,
+				CheckOldLine: true,
+			},
+			resultContent: `package main
+
+import (
+	"fmt"
+)
+
+//line /test/main.go:10
+func main() {
+}
+`,
+		},
+	}
+
+	for idx := range cases {
+		tc := cases[idx]
+		parsed, parseErr := decorator.ParseFile(nil, "test.go", tc.fileContent, parser.ParseComments)
+		if parseErr != nil {
+			t.Fatal(parseErr)
+		}
+		generated, genErr := GenerateDSTFileContent(parsed, tc.debugInfo)
+		if genErr != nil {
+			t.Fatal(genErr)
+		}
+		if generated == tc.resultContent {
+			continue
+		}
+		t.Fatalf("case %d: expect %s, actual %s", idx, tc.resultContent, generated)
+	}
+}
+
+func TestImportAnalyzer(t *testing.T) {
+ tests := []struct {
+ imports []string
+ fieldsCode string
+ usedImports map[string]string
+ }{
+ {
+ imports: []string{`test1 "test/test1"`},
+ fieldsCode: `v1 test1.Int`,
+ usedImports: map[string]string{
+ "test1": "test/test1",
+ },
+ },
+ {
+ imports: []string{`test1 "test/test1"`},
+ fieldsCode: `v1 int`,
+ usedImports: map[string]string{},
+ },
+ {
+ imports: []string{`test1 "test/test1"`, `test2
"test/test2"`},
+ fieldsCode: `v1 []test2.Int`,
+ usedImports: map[string]string{
+ "test2": "test/test2",
+ },
+ },
+ }
+
+ for i, test := range tests {
+ content := ""
+ for _, imp := range test.imports {
+ content += "import " + imp + "\n"
+ }
+ content += fmt.Sprintf("func test(%s) {}", test.fieldsCode)
+
+ f := &dst.File{
+ Name: dst.NewIdent("test.go"),
+ Decls: GoStringToDecls(content),
+ }
+ analyzer := CreateImportAnalyzer()
+ analyzer.AnalyzeFileImports("test.go", f)
+ analyzer.AnalyzeNeedsImports("test.go",
f.Decls[len(f.Decls)-1].(*dst.FuncDecl).Type.Params)
+
+ if len(test.usedImports) != len(analyzer.usedImports) {
+ t.Fatalf("case %d: expect %d used imports, actual %d",
i, len(test.usedImports), len(analyzer.usedImports))
+ }
+ for name, path := range test.usedImports {
+ spec := analyzer.usedImports[name]
+ if spec == nil {
+ t.Fatalf("case %d: expect use %s, actual nil",
i, name)
+ }
+ if spec.Path.Value != fmt.Sprintf("%q", path) {
+ t.Fatalf("case %d: expect use %s, actual %s",
i, path, spec.Path.Value)
+ }
+ }
+ }
+}
diff --git a/tools/go-agent/tools/enhancement.go b/tools/go-agent/tools/enhancement.go
index 927b06d..7b800fd 100644
--- a/tools/go-agent/tools/enhancement.go
+++ b/tools/go-agent/tools/enhancement.go
@@ -42,16 +42,41 @@ type PackagedParameterInfo struct {
PackageName string
}
+// FieldListType identifies which kind of field list is being processed:
+// function parameters, function results, or a method receiver. It selects the
+// prefix for names generated for unnamed fields (skywalking_param_N,
+// skywalking_result_N, skywalking_recv_N).
+type FieldListType int
+
+const (
+ // FieldListTypeParam marks a function's parameter list.
+ FieldListTypeParam FieldListType = iota
+ // FieldListTypeResult marks a function's result list.
+ FieldListTypeResult
+ // FieldListTypeRecv marks a method's receiver list.
+ FieldListTypeRecv
+)
+
+// String returns a short lowercase name for the field list kind: "param",
+// "result" or "recv". Values outside the declared constants are rendered as
+// "FieldListType(N)". They then stay visible in %s-formatted diagnostics
+// instead of collapsing to an empty string.
+func (f FieldListType) String() string {
+	switch f {
+	case FieldListTypeRecv:
+		return "recv"
+	case FieldListTypeParam:
+		return "param"
+	case FieldListTypeResult:
+		return "result"
+	default:
+		// Convert to int so formatting does not depend on this method.
+		return fmt.Sprintf("FieldListType(%d)", int(f))
+	}
+}
+
// EnhanceParameterNames enhance the parameter names if they are missing
-func EnhanceParameterNames(fields *dst.FieldList, isResult bool)
[]*ParameterInfo {
+func EnhanceParameterNames(fields *dst.FieldList, fieldType FieldListType)
[]*ParameterInfo {
if fields == nil {
return nil
}
result := make([]*ParameterInfo, 0)
for i, f := range fields.List {
- defineName := fmt.Sprintf("skywalking_param_%d", i)
- if isResult {
+ var defineName string
+ switch fieldType {
+ case FieldListTypeParam:
+ defineName = fmt.Sprintf("skywalking_param_%d", i)
+ case FieldListTypeResult:
defineName = fmt.Sprintf("skywalking_result_%d", i)
+ case FieldListTypeRecv:
+ defineName = fmt.Sprintf("skywalking_recv_%d", i)
}
if len(f.Names) == 0 {
f.Names = []*dst.Ident{{Name: defineName}}
@@ -73,8 +98,8 @@ func EnhanceParameterNames(fields *dst.FieldList, isResult
bool) []*ParameterInf
return result
}
-func EnhanceParameterNamesWithPackagePrefix(pkg string, fields *dst.FieldList,
isResult bool) []*PackagedParameterInfo {
- params := EnhanceParameterNames(fields, isResult)
+func EnhanceParameterNamesWithPackagePrefix(pkg string, fields *dst.FieldList,
fieldListType FieldListType) []*PackagedParameterInfo {
+ params := EnhanceParameterNames(fields, fieldListType)
result := make([]*PackagedParameterInfo, 0)
for _, p := range params {
result = append(result, &PackagedParameterInfo{ParameterInfo:
*p, PackageName: pkg})
diff --git a/tools/go-agent/tools/enhancement_test.go b/tools/go-agent/tools/enhancement_test.go
new file mode 100644
index 0000000..ea9058d
--- /dev/null
+++ b/tools/go-agent/tools/enhancement_test.go
@@ -0,0 +1,102 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package tools
+
+import (
+ "testing"
+
+ "github.com/dave/dst"
+)
+
+// buildParameterValidateInfo builds the expected ParameterInfo for one field.
+// It carries the field's (possibly generated) name, its type name and the
+// string form of its default value.
+func buildParameterValidateInfo(name, typeName, defaultValue string) *ParameterInfo {
+	info := new(ParameterInfo)
+	info.Name = name
+	info.TypeName = typeName
+	info.DefaultValueAsString = defaultValue
+	return info
+}
+
+// TestEnhanceParameterNames runs EnhanceParameterNames over the receiver,
+// parameter and result lists of each function. Unnamed fields must receive
+// the generated skywalking_recv_N / skywalking_param_N / skywalking_result_N
+// names, while named fields keep their own names. In both cases the type name
+// and default value are checked as well.
+func TestEnhanceParameterNames(t *testing.T) {
+	tests := []struct {
+		funcCode string
+		recvs    []*ParameterInfo
+		params   []*ParameterInfo
+		results  []*ParameterInfo
+	}{
+		{
+			funcCode: `func (*Example) Test(int) bool {
+			return false
+		}`,
+			recvs: []*ParameterInfo{
+				buildParameterValidateInfo("skywalking_recv_0", "*Example", "nil"),
+			},
+			params: []*ParameterInfo{
+				buildParameterValidateInfo("skywalking_param_0", "int", "0"),
+			},
+			results: []*ParameterInfo{
+				buildParameterValidateInfo("skywalking_result_0", "bool", "false"),
+			},
+		},
+		{
+			funcCode: `func (e *Example) Test(i int) (b bool) {
+	return false
+}`,
+			recvs: []*ParameterInfo{
+				buildParameterValidateInfo("e", "*Example", "nil"),
+			},
+			params: []*ParameterInfo{
+				buildParameterValidateInfo("i", "int", "0"),
+			},
+			results: []*ParameterInfo{
+				buildParameterValidateInfo("b", "bool", "false"),
+			},
+		},
+	}
+
+	for idx, tc := range tests {
+		decl := GoStringToDecls(tc.funcCode)[0]
+		fun := decl.(*dst.FuncDecl)
+
+		// A plain function has no receiver list; leave gotRecvs nil then.
+		var gotRecvs []*ParameterInfo
+		if fun.Recv != nil {
+			gotRecvs = EnhanceParameterNames(fun.Recv, FieldListTypeRecv)
+		}
+		gotParams := EnhanceParameterNames(fun.Type.Params, FieldListTypeParam)
+		gotResults := EnhanceParameterNames(fun.Type.Results, FieldListTypeResult)
+
+		validateParameterInfo(t, idx, FieldListTypeRecv, gotRecvs, tc.recvs)
+		validateParameterInfo(t, idx, FieldListTypeParam, gotParams, tc.params)
+		validateParameterInfo(t, idx, FieldListTypeResult, gotResults, tc.results)
+	}
+}
+
+func validateParameterInfo(t *testing.T, inx int, flistType FieldListType,
actual, excepted []*ParameterInfo) {
+ if len(actual) != len(excepted) {
+ t.Errorf("case %d:%s: expected count %d , actual %d", inx,
flistType, len(excepted), len(actual))
+ }
+ for i, exp := range excepted {
+ act := actual[i]
+ if exp.Name != act.Name {
+ t.Errorf("case %d:%s: expected name %s , actual %s",
inx, flistType, exp.Name, act.Name)
+ }
+ if exp.TypeName != act.TypeName {
+ t.Errorf("case %d:%s: expected type %s , actual %s",
inx, flistType, exp.TypeName, act.TypeName)
+ }
+ if exp.DefaultValueAsString != act.DefaultValueAsString {
+ t.Errorf("case %d:%s: expected default value %s ,
actual %s", inx, flistType, exp.DefaultValueAsString, act.DefaultValueAsString)
+ }
+ }
+}