This is an automated email from the ASF dual-hosted git repository.

liuhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 4bee1d4  Add more unit tests (#80)
4bee1d4 is described below

commit 4bee1d47d62b44ddc03b3b3c72dc546b45a5cc0b
Author: mrproliu <[email protected]>
AuthorDate: Thu Jul 20 14:27:36 2023 +0800

    Add more unit tests (#80)
---
 tools/go-agent/instrument/logger/frameworks/zap.go |  14 +-
 tools/go-agent/instrument/logger/instrument.go     |  12 +-
 .../go-agent/instrument/plugins/enhance_method.go  |   6 +-
 tools/go-agent/instrument/runtime/instrument.go    |   4 +-
 tools/go-agent/tools/directive_test.go             | 112 ++++++++++
 tools/go-agent/tools/dst_test.go                   | 247 +++++++++++++++++++++
 tools/go-agent/tools/enhancement.go                |  35 ++-
 tools/go-agent/tools/enhancement_test.go           | 102 +++++++++
 8 files changed, 509 insertions(+), 23 deletions(-)

diff --git a/tools/go-agent/instrument/logger/frameworks/zap.go b/tools/go-agent/instrument/logger/frameworks/zap.go
index 6af67d3..8554943 100644
--- a/tools/go-agent/instrument/logger/frameworks/zap.go
+++ b/tools/go-agent/instrument/logger/frameworks/zap.go
@@ -150,8 +150,8 @@ func (z *Zap) CustomizedEnhance(path string, curFile *dst.File, cursor *dstutil.
                        n.Name.Name == "check" &&
                        n.Type.Results != nil && len(n.Type.Results.List) > 0 &&
                        
tools.GenerateTypeNameByExp(n.Type.Results.List[0].Type) == 
"*zapcore.CheckedEntry" {
-                       entryName := 
tools.EnhanceParameterNames(n.Type.Results, true)[0].Name
-                       recvName := tools.EnhanceParameterNames(n.Recv, 
true)[0].Name
+                       entryName := 
tools.EnhanceParameterNames(n.Type.Results, tools.FieldListTypeResult)[0].Name
+                       recvName := tools.EnhanceParameterNames(n.Recv, 
tools.FieldListTypeRecv)[0].Name
 
                        // init the zapcore variables
                        z.initFunction = 
tools.GoStringToDecls(fmt.Sprintf(`func initZapCore() {
@@ -171,9 +171,9 @@ zapcore.SWLogEnable = %s
                if n.Recv != nil && len(n.Recv.List) == 1 && n.Name.Name == 
"With" &&
                        n.Type.Params != nil && len(n.Type.Params.List) == 1 &&
                        tools.GenerateTypeNameByExp(n.Recv.List[0].Type) == 
"*Logger" && tools.GenerateTypeNameByExp(n.Type.Params.List[0].Type) == 
"[]Field" {
-                       recvs := tools.EnhanceParameterNames(n.Recv, false)
-                       parameters := 
tools.EnhanceParameterNames(n.Type.Params, false)
-                       results := tools.EnhanceParameterNames(n.Type.Results, 
true)
+                       recvs := tools.EnhanceParameterNames(n.Recv, 
tools.FieldListTypeRecv)
+                       parameters := 
tools.EnhanceParameterNames(n.Type.Params, tools.FieldListTypeParam)
+                       results := tools.EnhanceParameterNames(n.Type.Results, 
tools.FieldListTypeResult)
 
                        return z.enhanceMethod(n, fmt.Sprintf(`defer func() {if 
%s != nil { %s.SWFields = %sZap%s(%s, %s.SWFields) }}()`,
                                results[0].Name, results[0].Name, 
rewrite.StaticMethodPrefix, "KnownFieldFilter", parameters[0].Name, 
recvs[0].Name)), true
@@ -184,8 +184,8 @@ zapcore.SWLogEnable = %s
                        n.Name.Name == "Write" &&
                        n.Type.Params != nil && len(n.Type.Params.List) == 1 &&
                        tools.GenerateTypeNameByExp(n.Type.Params.List[0].Type) 
== "[]Field" {
-                       recvs := tools.EnhanceParameterNames(n.Recv, false)
-                       parameters := 
tools.EnhanceParameterNames(n.Type.Params, false)
+                       recvs := tools.EnhanceParameterNames(n.Recv, 
tools.FieldListTypeRecv)
+                       parameters := 
tools.EnhanceParameterNames(n.Type.Params, tools.FieldListTypeParam)
                        return z.enhanceMethod(n, fmt.Sprintf(`if %s != nil { 
%s = %sZapcore%s(%s, %s, %s.SWFields, %s.SWContext, %s.SWContextField, 
SWReporterEnable, SWLogEnable, SWReporterLabelKeys) }`,
                                recvs[0].Name, parameters[0].Name, 
rewrite.StaticMethodPrefix, "ReportLogFromZapEntry", recvs[0].Name,
                                parameters[0].Name, recvs[0].Name, 
recvs[0].Name, recvs[0].Name)), true
diff --git a/tools/go-agent/instrument/logger/instrument.go b/tools/go-agent/instrument/logger/instrument.go
index 38e9caa..bd05234 100644
--- a/tools/go-agent/instrument/logger/instrument.go
+++ b/tools/go-agent/instrument/logger/instrument.go
@@ -117,9 +117,9 @@ func (i *Instrument) addAutomaticBindFunc(path string, curFile dst.Node, fun *ds
                Results               []*tools.ParameterInfo
        }{
                AutomaticBindFuncName: generateFuncName,
-               Recvs:                 tools.EnhanceParameterNames(fun.Recv, 
false),
-               Parameters:            
tools.EnhanceParameterNames(fun.Type.Params, false),
-               Results:               
tools.EnhanceParameterNames(fun.Type.Results, true),
+               Recvs:                 tools.EnhanceParameterNames(fun.Recv, 
tools.FieldListTypeRecv),
+               Parameters:            
tools.EnhanceParameterNames(fun.Type.Params, tools.FieldListTypeParam),
+               Results:               
tools.EnhanceParameterNames(fun.Type.Results, tools.FieldListTypeResult),
        })
        importAnalyzer.AnalyzeNeedsImports(path, fun.Recv)
        importAnalyzer.AnalyzeNeedsImports(path, fun.Type.Params)
@@ -307,14 +307,14 @@ func (i *Instrument) addAutomaticFuncDelegators(f *dst.File, importDecl *dst.Gen
                        },
                }
 
-               for i, recv := range 
tools.EnhanceParameterNamesWithPackagePrefix(packageName, fun.Func.Recv, false) 
{
+               for i, recv := range 
tools.EnhanceParameterNamesWithPackagePrefix(packageName, fun.Func.Recv, 
tools.FieldListTypeRecv) {
                        delegatorFunc.Type.Params.List = 
append(delegatorFunc.Type.Params.List, &dst.Field{
                                Names: 
[]*dst.Ident{dst.NewIdent(fmt.Sprintf("recv_%d", i))},
                                Type:  &dst.StarExpr{X: recv.PackagedType()},
                        })
                }
 
-               for i, parameter := range 
tools.EnhanceParameterNamesWithPackagePrefix(packageName, fun.Func.Type.Params, 
false) {
+               for i, parameter := range 
tools.EnhanceParameterNamesWithPackagePrefix(packageName, fun.Func.Type.Params, 
tools.FieldListTypeParam) {
                        packagedType := parameter.PackagedType()
                        // if the parameter is dynamic list, then change it to 
the array type
                        if el, ok := packagedType.(*dst.Ellipsis); ok {
@@ -326,7 +326,7 @@ func (i *Instrument) addAutomaticFuncDelegators(f *dst.File, importDecl *dst.Gen
                        })
                }
 
-               for i, result := range 
tools.EnhanceParameterNamesWithPackagePrefix(packageName, 
fun.Func.Type.Results, true) {
+               for i, result := range 
tools.EnhanceParameterNamesWithPackagePrefix(packageName, 
fun.Func.Type.Results, tools.FieldListTypeResult) {
                        delegatorFunc.Type.Params.List = 
append(delegatorFunc.Type.Params.List, &dst.Field{
                                Names: 
[]*dst.Ident{dst.NewIdent(fmt.Sprintf("ret_%d", i))},
                                Type:  &dst.StarExpr{X: result.PackagedType()},
diff --git a/tools/go-agent/instrument/plugins/enhance_method.go b/tools/go-agent/instrument/plugins/enhance_method.go
index 4fdf38a..8ebd85a 100644
--- a/tools/go-agent/instrument/plugins/enhance_method.go
+++ b/tools/go-agent/instrument/plugins/enhance_method.go
@@ -77,11 +77,11 @@ func NewMethodEnhance(inst instrument.Instrument, matcher *instrument.Point, f *
                packageName:           pkgName,
                InstrumentName:        inst.Name(),
                InterceptorDefineName: matcher.Interceptor,
-               Parameters:            
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Type.Params, false),
-               Results:               
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Type.Results, true),
+               Parameters:            
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Type.Params, 
tools.FieldListTypeParam),
+               Results:               
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Type.Results, 
tools.FieldListTypeResult),
        }
        if f.Recv != nil {
-               enhance.Recvs = 
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Recv, false)
+               enhance.Recvs = 
tools.EnhanceParameterNamesWithPackagePrefix(pkgName, f.Recv, 
tools.FieldListTypeRecv)
        }
 
        importAnalyzer.AnalyzeNeedsImports(path, f.Type.Params)
diff --git a/tools/go-agent/instrument/runtime/instrument.go b/tools/go-agent/instrument/runtime/instrument.go
index 08d2073..967ae7c 100644
--- a/tools/go-agent/instrument/runtime/instrument.go
+++ b/tools/go-agent/instrument/runtime/instrument.go
@@ -69,8 +69,8 @@ func (r *Instrument) FilterAndEdit(path string, curFile *dst.File, cursor *dstut
                if len(n.Type.Params.List) != 3 {
                        return false
                }
-               parameters := tools.EnhanceParameterNames(n.Type.Params, false)
-               results := tools.EnhanceParameterNames(n.Type.Results, true)
+               parameters := tools.EnhanceParameterNames(n.Type.Params, 
tools.FieldListTypeParam)
+               results := tools.EnhanceParameterNames(n.Type.Results, 
tools.FieldListTypeResult)
 
                tools.InsertStmtsBeforeBody(n.Body, `defer func() {
        {{(index .Results 0).Name}}.{{.TLSField}} = goroutineChange({{(index 
.Parameters 1).Name}}.{{.TLSField}})
diff --git a/tools/go-agent/tools/directive_test.go b/tools/go-agent/tools/directive_test.go
new file mode 100644
index 0000000..8d25627
--- /dev/null
+++ b/tools/go-agent/tools/directive_test.go
@@ -0,0 +1,112 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package tools
+
+import (
+       "testing"
+
+       "github.com/apache/skywalking-go/tools/go-agent/instrument/consts"
+)
+
+func TestContainsDirective(t *testing.T) {
+       var tests = []struct {
+               goCode    string
+               directive string
+               contains  bool
+       }{
+               {
+                       goCode: `//skywalking:nocopy
+func test() {}`,
+                       directive: consts.DirecitveNoCopy,
+                       contains:  true,
+               },
+               {
+                       goCode: `// test method
+//skywalking:nocopy
+// test method
+func test1() {}`,
+                       directive: consts.DirecitveNoCopy,
+                       contains:  true,
+               },
+               {
+                       goCode: `func test1() {}
+//skywalking:nocopy
+`,
+                       directive: consts.DirecitveNoCopy,
+                       contains:  false,
+               },
+               {
+                       goCode: `// skywalking:nocopy
+func test2() {}`,
+                       directive: consts.DirecitveNoCopy,
+                       contains:  false,
+               },
+       }
+
+       for _, test := range tests {
+               decls := GoStringToDecls(test.goCode)
+               contains := ContainsDirective(decls[0], test.directive)
+               if contains != test.contains {
+                       t.Errorf("ContainsDirective(%s, %s) = %v, excepted %v", 
test.goCode, test.directive, contains, test.contains)
+               }
+       }
+}
+
+func TestFindDirective(t *testing.T) {
+       var tests = []struct {
+               goCode    string
+               directive string
+               found     string
+       }{
+               {
+                       goCode: `//skywalking:nocopy
+func test() {}`,
+                       directive: consts.DirecitveNoCopy,
+                       found:     "//skywalking:nocopy",
+               },
+               {
+                       goCode: `// test method
+//skywalking:nocopy
+// test method
+func test1() {}`,
+                       directive: consts.DirecitveNoCopy,
+                       found:     "//skywalking:nocopy",
+               },
+               {
+                       goCode: `func test1() {}
+//skywalking:nocopy
+`,
+                       directive: consts.DirecitveNoCopy,
+                       found:     "",
+               },
+               {
+                       goCode: `//skywalking:native test method
+func test2() {}`,
+                       directive: consts.DirectiveNative,
+                       found:     "//skywalking:native test method",
+               },
+       }
+
+       for _, test := range tests {
+               decls := GoStringToDecls(test.goCode)
+               found := FindDirective(decls[0], test.directive)
+               if found != test.found {
+                       t.Errorf("FindDirective(%s, %s) = %v, excepted %v", 
test.goCode, test.directive, found, test.found)
+               }
+       }
+}
diff --git a/tools/go-agent/tools/dst_test.go b/tools/go-agent/tools/dst_test.go
new file mode 100644
index 0000000..84b0c65
--- /dev/null
+++ b/tools/go-agent/tools/dst_test.go
@@ -0,0 +1,247 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package tools
+
+import (
+       "fmt"
+       "go/parser"
+       "reflect"
+       "strings"
+       "testing"
+
+       "github.com/dave/dst"
+       "github.com/dave/dst/decorator"
+)
+
+func TestChangePackageImportPath(t *testing.T) {
+       tests := []struct {
+               file       string
+               pkgUpdates map[string]string
+               actual     []string
+       }{
+               {
+                       file: `package main
+import (
+       "github.com/apache/skywalking-go/plugins/core/operator"
+)`,
+                       pkgUpdates: map[string]string{
+                               
"github.com/apache/skywalking-go/plugins/core/operator": 
"github.com/apache/skywalking-go/agent/core/operator",
+                       },
+                       actual: []string{
+                               
"github.com/apache/skywalking-go/agent/core/operator",
+                       },
+               },
+               {
+                       file: `package main
+import (
+       "fmt"
+       "github.com/apache/skywalking-go/agent/core/operator"
+)`,
+                       pkgUpdates: map[string]string{
+                               
"github.com/apache/skywalking-go/plugins/core/operator": 
"github.com/apache/skywalking-go/agent/core/operator",
+                       },
+                       actual: []string{
+                               "fmt",
+                               
"github.com/apache/skywalking-go/agent/core/operator",
+                       },
+               },
+       }
+
+       for _, test := range tests {
+               f, err := decorator.ParseFile(nil, "test.go", test.file, 
parser.ParseComments)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               ChangePackageImportPath(f, test.pkgUpdates)
+               actual := make([]string, 0)
+               for _, i := range f.Imports {
+                       actual = append(actual, 
strings.TrimSuffix(strings.TrimPrefix(i.Path.Value, "\""), "\""))
+               }
+               if !reflect.DeepEqual(actual, test.actual) {
+                       t.Fatalf("expect %v, actual %v", test.actual, actual)
+               }
+       }
+}
+
+func TestDeletePackageImports(t *testing.T) {
+       tests := []struct {
+               goCode   string
+               validate func(result dst.Node) bool
+               isValue  bool
+       }{
+               {
+                       goCode:  "test.Count(1)",
+                       isValue: true,
+                       validate: func(result dst.Node) bool {
+                               call := result.(*dst.CallExpr)
+                               return reflect.DeepEqual(call.Fun, 
dst.NewIdent("Count"))
+                       },
+               },
+               {
+                       goCode:  "[]test.Int{}",
+                       isValue: true,
+                       validate: func(result dst.Node) bool {
+                               call := result.(*dst.CompositeLit)
+                               return reflect.DeepEqual(call.Type, 
&dst.ArrayType{
+                                       Elt: dst.NewIdent("Int"),
+                               })
+                       },
+               },
+               {
+                       goCode: `type Object struct {
+       value test.Int
+}`,
+                       isValue: false,
+                       validate: func(result dst.Node) bool {
+                               structType := 
result.(*dst.GenDecl).Specs[0].(*dst.TypeSpec).Type.(*dst.StructType)
+                               return 
reflect.DeepEqual(structType.Fields.List[0].Type, dst.NewIdent("Int"))
+                       },
+               },
+       }
+
+       for i, test := range tests {
+               content := "import test \"testpackage/test\"\n"
+               if test.isValue {
+                       content += "var val = " + test.goCode
+               } else {
+                       content += test.goCode
+               }
+               decls := GoStringToDecls(content)
+
+               file := &dst.File{Name: dst.NewIdent("dst"), Decls: decls}
+               DeletePackageImports(file, "testpackage/test")
+               if len(file.Decls) != 1 {
+                       t.Errorf("failure to delete package, current decl 
count: %d", len(file.Decls))
+               }
+               var actualResult dst.Node
+               if test.isValue {
+                       actualResult = 
file.Decls[0].(*dst.GenDecl).Specs[0].(*dst.ValueSpec).Values[0]
+               } else {
+                       actualResult = file.Decls[0]
+               }
+               if !test.validate(actualResult) {
+                       t.Fatalf("validate %d error, real get result: %v", i, 
actualResult)
+               }
+       }
+}
+
+func TestGenerateDSTFileContent(t *testing.T) {
+       tests := []struct {
+               fileContent   string
+               debugInfo     *DebugInfo
+               resultContent string
+       }{
+               {
+                       fileContent: `package main
+
+import (
+       "fmt"
+)
+
+func main() {
+}
+`,
+                       debugInfo: &DebugInfo{
+                               FilePath:     "/test/main.go",
+                               Line:         10,
+                               CheckOldLine: true,
+                       },
+                       resultContent: `package main
+
+import (
+       "fmt"
+)
+
+//line /test/main.go:10
+func main() {
+}
+`,
+               },
+       }
+
+       for i, test := range tests {
+               file, err := decorator.ParseFile(nil, "test.go", 
test.fileContent, parser.ParseComments)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               content, err := GenerateDSTFileContent(file, test.debugInfo)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if content != test.resultContent {
+                       t.Fatalf("case %d: expect %s, actual %s", i, 
test.resultContent, content)
+               }
+       }
+}
+
+func TestImportAnalyzer(t *testing.T) {
+       tests := []struct {
+               imports     []string
+               fieldsCode  string
+               usedImports map[string]string
+       }{
+               {
+                       imports:    []string{`test1 "test/test1"`},
+                       fieldsCode: `v1 test1.Int`,
+                       usedImports: map[string]string{
+                               "test1": "test/test1",
+                       },
+               },
+               {
+                       imports:     []string{`test1 "test/test1"`},
+                       fieldsCode:  `v1 int`,
+                       usedImports: map[string]string{},
+               },
+               {
+                       imports:    []string{`test1 "test/test1"`, `test2 
"test/test2"`},
+                       fieldsCode: `v1 []test2.Int`,
+                       usedImports: map[string]string{
+                               "test2": "test/test2",
+                       },
+               },
+       }
+
+       for i, test := range tests {
+               content := ""
+               for _, imp := range test.imports {
+                       content += "import " + imp + "\n"
+               }
+               content += fmt.Sprintf("func test(%s) {}", test.fieldsCode)
+
+               f := &dst.File{
+                       Name:  dst.NewIdent("test.go"),
+                       Decls: GoStringToDecls(content),
+               }
+               analyzer := CreateImportAnalyzer()
+               analyzer.AnalyzeFileImports("test.go", f)
+               analyzer.AnalyzeNeedsImports("test.go", 
f.Decls[len(f.Decls)-1].(*dst.FuncDecl).Type.Params)
+
+               if len(test.usedImports) != len(analyzer.usedImports) {
+                       t.Fatalf("case %d: expect %d used imports, actual %d", 
i, len(test.usedImports), len(analyzer.usedImports))
+               }
+               for name, path := range test.usedImports {
+                       spec := analyzer.usedImports[name]
+                       if spec == nil {
+                               t.Fatalf("case %d: expect use %s, actual nil", 
i, name)
+                       }
+                       if spec.Path.Value != fmt.Sprintf("%q", path) {
+                               t.Fatalf("case %d: expect use %s, actual %s", 
i, path, spec.Path.Value)
+                       }
+               }
+       }
+}
diff --git a/tools/go-agent/tools/enhancement.go b/tools/go-agent/tools/enhancement.go
index 927b06d..7b800fd 100644
--- a/tools/go-agent/tools/enhancement.go
+++ b/tools/go-agent/tools/enhancement.go
@@ -42,16 +42,41 @@ type PackagedParameterInfo struct {
        PackageName string
 }
 
+type FieldListType int
+
+const (
+       FieldListTypeParam FieldListType = iota
+       FieldListTypeResult
+       FieldListTypeRecv
+)
+
+func (f FieldListType) String() string {
+       switch f {
+       case FieldListTypeRecv:
+               return "recv"
+       case FieldListTypeParam:
+               return "param"
+       case FieldListTypeResult:
+               return "result"
+       }
+       return ""
+}
+
 // EnhanceParameterNames enhance the parameter names if they are missing
-func EnhanceParameterNames(fields *dst.FieldList, isResult bool) 
[]*ParameterInfo {
+func EnhanceParameterNames(fields *dst.FieldList, fieldType FieldListType) 
[]*ParameterInfo {
        if fields == nil {
                return nil
        }
        result := make([]*ParameterInfo, 0)
        for i, f := range fields.List {
-               defineName := fmt.Sprintf("skywalking_param_%d", i)
-               if isResult {
+               var defineName string
+               switch fieldType {
+               case FieldListTypeParam:
+                       defineName = fmt.Sprintf("skywalking_param_%d", i)
+               case FieldListTypeResult:
                        defineName = fmt.Sprintf("skywalking_result_%d", i)
+               case FieldListTypeRecv:
+                       defineName = fmt.Sprintf("skywalking_recv_%d", i)
                }
                if len(f.Names) == 0 {
                        f.Names = []*dst.Ident{{Name: defineName}}
@@ -73,8 +98,8 @@ func EnhanceParameterNames(fields *dst.FieldList, isResult bool) []*ParameterInf
        return result
 }
 
-func EnhanceParameterNamesWithPackagePrefix(pkg string, fields *dst.FieldList, 
isResult bool) []*PackagedParameterInfo {
-       params := EnhanceParameterNames(fields, isResult)
+func EnhanceParameterNamesWithPackagePrefix(pkg string, fields *dst.FieldList, 
fieldListType FieldListType) []*PackagedParameterInfo {
+       params := EnhanceParameterNames(fields, fieldListType)
        result := make([]*PackagedParameterInfo, 0)
        for _, p := range params {
                result = append(result, &PackagedParameterInfo{ParameterInfo: 
*p, PackageName: pkg})
diff --git a/tools/go-agent/tools/enhancement_test.go b/tools/go-agent/tools/enhancement_test.go
new file mode 100644
index 0000000..ea9058d
--- /dev/null
+++ b/tools/go-agent/tools/enhancement_test.go
@@ -0,0 +1,102 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package tools
+
+import (
+       "testing"
+
+       "github.com/dave/dst"
+)
+
+func buildParameterValidateInfo(name, typeName, defaultValue string) 
*ParameterInfo {
+       return &ParameterInfo{
+               Name:                 name,
+               TypeName:             typeName,
+               DefaultValueAsString: defaultValue,
+       }
+}
+
+func TestEnhanceParameterNames(t *testing.T) {
+       tests := []struct {
+               funcCode string
+               recvs    []*ParameterInfo
+               params   []*ParameterInfo
+               results  []*ParameterInfo
+       }{
+               {
+                       funcCode: `func (*Example) Test(int) bool {
+                               return false
+                       }`,
+                       recvs: []*ParameterInfo{
+                               buildParameterValidateInfo("skywalking_recv_0", 
"*Example", "nil"),
+                       },
+                       params: []*ParameterInfo{
+                               
buildParameterValidateInfo("skywalking_param_0", "int", "0"),
+                       },
+                       results: []*ParameterInfo{
+                               
buildParameterValidateInfo("skywalking_result_0", "bool", "false"),
+                       },
+               },
+               {
+                       funcCode: `func (e *Example) Test(i int) (b bool) {
+                               return false
+}`,
+                       recvs: []*ParameterInfo{
+                               buildParameterValidateInfo("e", "*Example", 
"nil"),
+                       },
+                       params: []*ParameterInfo{
+                               buildParameterValidateInfo("i", "int", "0"),
+                       },
+                       results: []*ParameterInfo{
+                               buildParameterValidateInfo("b", "bool", 
"false"),
+                       },
+               },
+       }
+
+       for i, test := range tests {
+               fun := GoStringToDecls(test.funcCode)[0].(*dst.FuncDecl)
+               var actualRecv, actualParams, actualResults []*ParameterInfo
+               if fun.Recv != nil {
+                       actualRecv = EnhanceParameterNames(fun.Recv, 
FieldListTypeRecv)
+               }
+               actualParams = EnhanceParameterNames(fun.Type.Params, 
FieldListTypeParam)
+               actualResults = EnhanceParameterNames(fun.Type.Results, 
FieldListTypeResult)
+
+               validateParameterInfo(t, i, FieldListTypeRecv, actualRecv, 
test.recvs)
+               validateParameterInfo(t, i, FieldListTypeParam, actualParams, 
test.params)
+               validateParameterInfo(t, i, FieldListTypeResult, actualResults, 
test.results)
+       }
+}
+
+func validateParameterInfo(t *testing.T, inx int, flistType FieldListType, 
actual, excepted []*ParameterInfo) {
+       if len(actual) != len(excepted) {
+               t.Errorf("case %d:%s: expected count %d , actual %d", inx, 
flistType, len(excepted), len(actual))
+       }
+       for i, exp := range excepted {
+               act := actual[i]
+               if exp.Name != act.Name {
+                       t.Errorf("case %d:%s: expected name %s , actual %s", 
inx, flistType, exp.Name, act.Name)
+               }
+               if exp.TypeName != act.TypeName {
+                       t.Errorf("case %d:%s: expected type %s , actual %s", 
inx, flistType, exp.TypeName, act.TypeName)
+               }
+               if exp.DefaultValueAsString != act.DefaultValueAsString {
+                       t.Errorf("case %d:%s: expected default value %s , 
actual %s", inx, flistType, exp.DefaultValueAsString, act.DefaultValueAsString)
+               }
+       }
+}

Reply via email to