This is an automated email from the ASF dual-hosted git repository.

juzhiyuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix-dashboard.git


The following commit(s) were added to refs/heads/master by this push:
     new f938c0c  feat: add a unit test for consumer and remove implicit init 
(#859)
f938c0c is described below

commit f938c0cb1c142b03082936e6560c195c8e784fad
Author: Vinci Xu <[email protected]>
AuthorDate: Fri Dec 11 16:50:12 2020 +0800

    feat: add a unit test for consumer and remove implicit init (#859)
    
    * feat: add a unit test for consumer and remove implicit init
    
    * fix: add implicit init function to stay compatible with integration tests
    
    * chore: add other consumer unit tests and refactor some code
    
    * fix: remove integration test in favor of unit tests
    
    * fix: add EOL for file
    
    * chore: use sub test to run table test
    
    * chore: test desc
    
    * chore: test desc
    
    Co-authored-by: Wen Ming <[email protected]>
---
 api/conf/conf.go                               |   4 +
 api/filter/logging_test.go                     |  34 +-
 api/internal/core/store/store.go               |   2 +-
 api/internal/core/store/store_mock.go          |  64 +++
 api/internal/handler/consumer/consumer.go      |  52 +--
 api/internal/handler/consumer/consumer_test.go | 597 ++++++++++++++++---------
 api/log/zap.go                                 |   5 +-
 7 files changed, 483 insertions(+), 275 deletions(-)

diff --git a/api/conf/conf.go b/api/conf/conf.go
index b34c339..6ea9af0 100644
--- a/api/conf/conf.go
+++ b/api/conf/conf.go
@@ -101,7 +101,11 @@ type Config struct {
        Authentication Authentication
 }
 
+// TODO: it is just for integration tests, we should call "InitConf" explicitly 
when all handlers' integration tests are removed
 func init() {
+       InitConf()
+}
+func InitConf() {
        //go test
        if workDir := os.Getenv("APISIX_API_WORKDIR"); workDir != "" {
                WorkDir = workDir
diff --git a/api/filter/logging_test.go b/api/filter/logging_test.go
index 087d8b0..688befa 100644
--- a/api/filter/logging_test.go
+++ b/api/filter/logging_test.go
@@ -17,30 +17,30 @@
 package filter
 
 import (
-        "net/http"
-        "net/http/httptest"
-        "testing"
+       "net/http"
+       "net/http/httptest"
+       "testing"
 
-        "github.com/gin-gonic/gin"
-        "github.com/stretchr/testify/assert"
+       "github.com/gin-gonic/gin"
+       "github.com/stretchr/testify/assert"
 
-        "github.com/apisix/manager-api/log"
+       "github.com/apisix/manager-api/log"
 )
 
 func performRequest(r http.Handler, method, path string) 
*httptest.ResponseRecorder {
-        req := httptest.NewRequest(method, path, nil)
-        w := httptest.NewRecorder()
-        r.ServeHTTP(w, req)
-        return w
+       req := httptest.NewRequest(method, path, nil)
+       w := httptest.NewRecorder()
+       r.ServeHTTP(w, req)
+       return w
 }
 
 func TestRequestLogHandler(t *testing.T) {
-        r := gin.New()
-        logger := log.GetLogger(log.AccessLog)
-        r.Use(RequestLogHandler(logger))
-        r.GET("/", func(c *gin.Context) {
-        })
+       r := gin.New()
+       logger := log.GetLogger(log.AccessLog)
+       r.Use(RequestLogHandler(logger))
+       r.GET("/", func(c *gin.Context) {
+       })
 
-        w := performRequest(r, "GET", "/")
-        assert.Equal(t, 200, w.Code)
+       w := performRequest(r, "GET", "/")
+       assert.Equal(t, 200, w.Code)
 }
diff --git a/api/internal/core/store/store.go b/api/internal/core/store/store.go
index d9c7e36..ee3fa50 100644
--- a/api/internal/core/store/store.go
+++ b/api/internal/core/store/store.go
@@ -38,7 +38,7 @@ type Interface interface {
        Get(key string) (interface{}, error)
        List(input ListInput) (*ListOutput, error)
        Create(ctx context.Context, obj interface{}) error
-       Update(ctx context.Context, obj interface{}, createOnFail bool) error
+       Update(ctx context.Context, obj interface{}, createIfNotExist bool) 
error
        BatchDelete(ctx context.Context, keys []string) error
 }
 
diff --git a/api/internal/core/store/store_mock.go 
b/api/internal/core/store/store_mock.go
new file mode 100644
index 0000000..258f083
--- /dev/null
+++ b/api/internal/core/store/store_mock.go
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package store
+
+import (
+       "context"
+       "github.com/stretchr/testify/mock"
+)
+
+type MockInterface struct {
+       mock.Mock
+}
+
+func (m *MockInterface) Get(key string) (interface{}, error) {
+       ret := m.Mock.Called(key)
+       return ret.Get(0), ret.Error(1)
+}
+
+func (m *MockInterface) List(input ListInput) (*ListOutput, error) {
+       ret := m.Called(input)
+
+       var (
+               r0 *ListOutput
+               r1 error
+       )
+
+       if rf, ok := ret.Get(0).(func(ListInput) *ListOutput); ok {
+               r0 = rf(input)
+       } else {
+               r0 = ret.Get(0).(*ListOutput)
+       }
+       r1 = ret.Error(1)
+
+       return r0, r1
+}
+
+func (m *MockInterface) Create(ctx context.Context, obj interface{}) error {
+       ret := m.Mock.Called(ctx, obj)
+       return ret.Error(0)
+}
+
+func (m *MockInterface) Update(ctx context.Context, obj interface{}, 
createOnFail bool) error {
+       ret := m.Mock.Called(ctx, obj, createOnFail)
+       return ret.Error(0)
+}
+
+func (m *MockInterface) BatchDelete(ctx context.Context, keys []string) error {
+       ret := m.Mock.Called(ctx, keys)
+       return ret.Error(0)
+}
diff --git a/api/internal/handler/consumer/consumer.go 
b/api/internal/handler/consumer/consumer.go
index a5854c5..31b7af3 100644
--- a/api/internal/handler/consumer/consumer.go
+++ b/api/internal/handler/consumer/consumer.go
@@ -17,21 +17,17 @@
 package consumer
 
 import (
-       "fmt"
-       "net/http"
        "reflect"
        "strings"
 
        "github.com/gin-gonic/gin"
        "github.com/shiningrush/droplet"
-       "github.com/shiningrush/droplet/data"
        "github.com/shiningrush/droplet/wrapper"
        wgin "github.com/shiningrush/droplet/wrapper/gin"
 
        "github.com/apisix/manager-api/internal/core/entity"
        "github.com/apisix/manager-api/internal/core/store"
        "github.com/apisix/manager-api/internal/handler"
-       "github.com/apisix/manager-api/internal/utils"
 )
 
 type Handler struct {
@@ -56,7 +52,7 @@ func (h *Handler) ApplyRoute(r *gin.Engine) {
        r.PUT("/apisix/admin/consumers", wgin.Wraps(h.Update,
                wrapper.InputType(reflect.TypeOf(UpdateInput{}))))
        r.DELETE("/apisix/admin/consumers/:usernames", wgin.Wraps(h.BatchDelete,
-               wrapper.InputType(reflect.TypeOf(BatchDelete{}))))
+               wrapper.InputType(reflect.TypeOf(BatchDeleteInput{}))))
 }
 
 type GetInput struct {
@@ -134,19 +130,9 @@ func (h *Handler) List(c droplet.Context) (interface{}, 
error) {
 
 func (h *Handler) Create(c droplet.Context) (interface{}, error) {
        input := c.Input().(*entity.Consumer)
-       if input.ID != nil && utils.InterfaceToString(input.ID) != 
input.Username {
-               return &data.SpecCodeResponse{StatusCode: 
http.StatusBadRequest},
-                       fmt.Errorf("consumer's id and username must be a same 
value")
-       }
        input.ID = input.Username
 
-       if _, ok := input.Plugins["jwt-auth"]; ok {
-               jwt := input.Plugins["jwt-auth"].(map[string]interface{})
-               jwt["exp"] = 86400
-
-               input.Plugins["jwt-auth"] = jwt
-       }
-
+       ensurePluginsDefValue(input.Plugins)
        if err := h.consumerStore.Create(c.Context(), input); err != nil {
                return handler.SpecCodeResponse(err), err
        }
@@ -161,42 +147,34 @@ type UpdateInput struct {
 
 func (h *Handler) Update(c droplet.Context) (interface{}, error) {
        input := c.Input().(*UpdateInput)
-       if input.ID != nil && utils.InterfaceToString(input.ID) != 
input.Username {
-               return &data.SpecCodeResponse{StatusCode: 
http.StatusBadRequest},
-                       fmt.Errorf("consumer's id and username must be a same 
value")
-       }
        if input.Username != "" {
                input.Consumer.Username = input.Username
        }
        input.Consumer.ID = input.Consumer.Username
-
-       if _, ok := input.Consumer.Plugins["jwt-auth"]; ok {
-               jwt := 
input.Consumer.Plugins["jwt-auth"].(map[string]interface{})
-               jwt["exp"] = 86400
-
-               input.Consumer.Plugins["jwt-auth"] = jwt
-       }
+       ensurePluginsDefValue(input.Plugins)
 
        if err := h.consumerStore.Update(c.Context(), &input.Consumer, true); 
err != nil {
-               //if not exists, create
-               if err.Error() == fmt.Sprintf("key: %s is not found", 
input.Username) {
-                       if err := h.consumerStore.Create(c.Context(), 
&input.Consumer); err != nil {
-                               return handler.SpecCodeResponse(err), err
-                       }
-               } else {
-                       return handler.SpecCodeResponse(err), err
-               }
+               return handler.SpecCodeResponse(err), err
        }
 
        return nil, nil
 }
 
-type BatchDelete struct {
+func ensurePluginsDefValue(plugins map[string]interface{}) {
+       if plugins["jwt-auth"] != nil {
+               jwtAuth, ok := plugins["jwt-auth"].(map[string]interface{})
+               if ok && jwtAuth["exp"] == nil {
+                       jwtAuth["exp"] = 86400
+               }
+       }
+}
+
+type BatchDeleteInput struct {
        UserNames string `auto_read:"usernames,path"`
 }
 
 func (h *Handler) BatchDelete(c droplet.Context) (interface{}, error) {
-       input := c.Input().(*BatchDelete)
+       input := c.Input().(*BatchDeleteInput)
 
        if err := h.consumerStore.BatchDelete(c.Context(), 
strings.Split(input.UserNames, ",")); err != nil {
                return handler.SpecCodeResponse(err), err
diff --git a/api/internal/handler/consumer/consumer_test.go 
b/api/internal/handler/consumer/consumer_test.go
index ba02f28..6a82154 100644
--- a/api/internal/handler/consumer/consumer_test.go
+++ b/api/internal/handler/consumer/consumer_test.go
@@ -18,232 +18,391 @@
 package consumer
 
 import (
-       "encoding/json"
-
-       "testing"
-       "time"
-
-       "github.com/shiningrush/droplet"
-       "github.com/stretchr/testify/assert"
-
-       "github.com/apisix/manager-api/conf"
+       "context"
+       "fmt"
        "github.com/apisix/manager-api/internal/core/entity"
-       "github.com/apisix/manager-api/internal/core/storage"
        "github.com/apisix/manager-api/internal/core/store"
+       "github.com/shiningrush/droplet"
+       "github.com/shiningrush/droplet/data"
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/mock"
+       "net/http"
+       "testing"
 )
 
-func TestConsumer(t *testing.T) {
-       // init
-       err := storage.InitETCDClient(conf.ETCDConfig)
-       assert.Nil(t, err)
-       err = store.InitStores()
-       assert.Nil(t, err)
+func TestHandler_Get(t *testing.T) {
+       tests := []struct {
+               caseDesc   string
+               giveInput  *GetInput
+               giveRet    interface{}
+               giveErr    error
+               wantErr    error
+               wantGetKey string
+               wantRet    interface{}
+       }{
+               {
+                       caseDesc:   "normal",
+                       giveInput:  &GetInput{Username: "test"},
+                       wantGetKey: "test",
+                       giveRet:    "hello",
+                       wantRet:    "hello",
+               },
+               {
+                       caseDesc:   "store get failed",
+                       giveInput:  &GetInput{Username: "failed key"},
+                       wantGetKey: "failed key",
+                       giveErr:    fmt.Errorf("get failed"),
+                       wantErr:    fmt.Errorf("get failed"),
+                       wantRet: &data.SpecCodeResponse{
+                               StatusCode: http.StatusInternalServerError,
+                       },
+               },
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.caseDesc, func(t *testing.T) {
+                       getCalled := true
+                       mStore := &store.MockInterface{}
+                       mStore.On("Get", mock.Anything).Run(func(args 
mock.Arguments) {
+                               getCalled = true
+                               assert.Equal(t, tc.wantGetKey, args.Get(0))
+                       }).Return(tc.giveRet, tc.giveErr)
+
+                       h := Handler{consumerStore: mStore}
+                       ctx := droplet.NewContext()
+                       ctx.SetInput(tc.giveInput)
+                       ret, err := h.Get(ctx)
+                       assert.True(t, getCalled)
+                       assert.Equal(t, tc.wantRet, ret)
+                       assert.Equal(t, tc.wantErr, err)
+               })
+       }
+}
+
+func TestHandler_List(t *testing.T) {
+       tests := []struct {
+               caseDesc  string
+               giveInput *ListInput
+               giveData  []*entity.Consumer
+               giveErr   error
+               wantErr   error
+               wantInput store.ListInput
+               wantRet   interface{}
+       }{
+               {
+                       caseDesc: "list all condition",
+                       giveInput: &ListInput{
+                               Username: "testUser",
+                               Pagination: store.Pagination{
+                                       PageSize:   10,
+                                       PageNumber: 10,
+                               },
+                       },
+                       wantInput: store.ListInput{
+                               PageSize:   10,
+                               PageNumber: 10,
+                       },
+                       giveData: []*entity.Consumer{
+                               {Username: "user1"},
+                               {Username: "testUser"},
+                               {Username: "iam-testUser"},
+                               {Username: "testUser-is-me"},
+                       },
+                       wantRet: &store.ListOutput{
+                               Rows: []interface{}{
+                                       &entity.Consumer{Username: "testUser"},
+                                       &entity.Consumer{Username: 
"iam-testUser"},
+                                       &entity.Consumer{Username: 
"testUser-is-me"},
+                               },
+                               TotalSize: 3,
+                       },
+               },
+               {
+                       caseDesc: "store list failed",
+                       giveInput: &ListInput{
+                               Username: "testUser",
+                               Pagination: store.Pagination{
+                                       PageSize:   10,
+                                       PageNumber: 10,
+                               },
+                       },
+                       wantInput: store.ListInput{
+                               PageSize:   10,
+                               PageNumber: 10,
+                       },
+                       giveData: []*entity.Consumer{},
+                       giveErr:  fmt.Errorf("list failed"),
+                       wantErr:  fmt.Errorf("list failed"),
+               },
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.caseDesc, func(t *testing.T) {
+                       getCalled := true
+                       mStore := &store.MockInterface{}
+                       mStore.On("List", mock.Anything).Run(func(args 
mock.Arguments) {
+                               getCalled = true
+                               input := args.Get(0).(store.ListInput)
+                               assert.Equal(t, tc.wantInput.PageSize, 
input.PageSize)
+                               assert.Equal(t, tc.wantInput.PageNumber, 
input.PageNumber)
+                       }).Return(func(input store.ListInput) *store.ListOutput 
{
+                               var returnData []interface{}
+                               for _, c := range tc.giveData {
+                                       if input.Predicate(c) {
+                                               returnData = append(returnData, 
c)
+                                       }
+                               }
+                               return &store.ListOutput{
+                                       Rows:      returnData,
+                                       TotalSize: len(returnData),
+                               }
+                       }, tc.giveErr)
 
-       handler := &Handler{
-               consumerStore: store.GetStore(store.HubKeyConsumer),
+                       h := Handler{consumerStore: mStore}
+                       ctx := droplet.NewContext()
+                       ctx.SetInput(tc.giveInput)
+                       ret, err := h.List(ctx)
+                       assert.True(t, getCalled)
+                       assert.Equal(t, tc.wantRet, ret)
+                       assert.Equal(t, tc.wantErr, err)
+               })
        }
-       assert.NotNil(t, handler)
-
-       //create consumer
-       ctx := droplet.NewContext()
-       consumer := &entity.Consumer{}
-       reqBody := `{
-      "username": "jack",
-      "plugins": {
-          "limit-count": {
-              "count": 2,
-              "time_window": 60,
-              "rejected_code": 503,
-              "key": "remote_addr"
-          }
-      },
-    "desc": "test description"
-  }`
-       err = json.Unmarshal([]byte(reqBody), consumer)
-       assert.Nil(t, err)
-       ctx.SetInput(consumer)
-       _, err = handler.Create(ctx)
-       assert.Nil(t, err)
-
-       //create consumer 2
-       consumer2 := &entity.Consumer{}
-       reqBody = `{
-               "username": "pony",
-               "plugins": {
-                 "limit-count": {
-                     "count": 2,
-                     "time_window": 60,
-                     "rejected_code": 503,
-                     "key": "remote_addr"
-                 }
+}
+
+func TestHandler_Create(t *testing.T) {
+       tests := []struct {
+               caseDesc   string
+               giveInput  *entity.Consumer
+               giveCtx    context.Context
+               giveErr    error
+               wantErr    error
+               wantInput  *entity.Consumer
+               wantRet    interface{}
+               wantCalled bool
+       }{
+               {
+                       caseDesc: "normal",
+                       giveInput: &entity.Consumer{
+                               Username: "name",
+                               Plugins: map[string]interface{}{
+                                       "jwt-auth": map[string]interface{}{},
+                               },
+                       },
+                       giveCtx: context.WithValue(context.Background(), 
"test", "value"),
+                       wantInput: &entity.Consumer{
+                               BaseInfo: entity.BaseInfo{
+                                       ID: "name",
+                               },
+                               Username: "name",
+                               Plugins: map[string]interface{}{
+                                       "jwt-auth": map[string]interface{}{
+                                               "exp": 86400,
+                                       },
+                               },
+                       },
+                       wantRet:    nil,
+                       wantCalled: true,
                },
-               "desc": "test description"
-       }`
-       err = json.Unmarshal([]byte(reqBody), consumer2)
-       assert.Nil(t, err)
-       ctx.SetInput(consumer2)
-       _, err = handler.Create(ctx)
-       assert.Nil(t, err)
-
-       //sleep
-       time.Sleep(time.Duration(100) * time.Millisecond)
-
-       //get consumer
-       input := &GetInput{}
-       reqBody = `{"username": "jack"}`
-       err = json.Unmarshal([]byte(reqBody), input)
-       assert.Nil(t, err)
-       ctx.SetInput(input)
-       ret, err := handler.Get(ctx)
-       stored := ret.(*entity.Consumer)
-       assert.Nil(t, err)
-       assert.Equal(t, stored.ID, consumer.ID)
-       assert.Equal(t, stored.Username, consumer.Username)
-
-       //update consumer
-       consumer3 := &UpdateInput{}
-       consumer3.Username = "pony"
-       reqBody = `{
-               "username": "pony",
-               "plugins": {
-                 "limit-count": {
-                     "count": 2,
-                     "time_window": 60,
-                     "rejected_code": 503,
-                     "key": "remote_addr"
-                 }
+               {
+                       caseDesc: "store create failed",
+                       giveInput: &entity.Consumer{
+                               Username: "name",
+                               Plugins: map[string]interface{}{
+                                       "jwt-auth": map[string]interface{}{
+                                               "exp": 5000,
+                                       },
+                               },
+                       },
+                       giveErr: fmt.Errorf("create failed"),
+                       wantInput: &entity.Consumer{
+                               BaseInfo: entity.BaseInfo{
+                                       ID: "name",
+                               },
+                               Username: "name",
+                               Plugins: map[string]interface{}{
+                                       "jwt-auth": map[string]interface{}{
+                                               "exp": 5000,
+                                       },
+                               },
+                       },
+                       wantErr: fmt.Errorf("create failed"),
+                       wantRet: &data.SpecCodeResponse{
+                               StatusCode: http.StatusInternalServerError,
+                       },
+                       wantCalled: true,
                },
-               "desc": "test description2"
-       }`
-       err = json.Unmarshal([]byte(reqBody), consumer3)
-       assert.Nil(t, err)
-       ctx.SetInput(consumer3)
-       _, err = handler.Update(ctx)
-       assert.Nil(t, err)
-
-       //sleep
-       time.Sleep(time.Duration(100) * time.Millisecond)
-
-       //check update
-       input3 := &GetInput{}
-       reqBody = `{"username": "pony"}`
-       err = json.Unmarshal([]byte(reqBody), input3)
-       assert.Nil(t, err)
-       ctx.SetInput(input3)
-       ret3, err := handler.Get(ctx)
-       stored3 := ret3.(*entity.Consumer)
-       assert.Nil(t, err)
-       assert.Equal(t, stored3.Desc, "test description2") //consumer3.Desc)
-       assert.Equal(t, stored3.Username, consumer3.Username)
-
-       //list page 1
-       listInput := &ListInput{}
-       reqBody = `{"page_size": 1, "page": 1}`
-       err = json.Unmarshal([]byte(reqBody), listInput)
-       assert.Nil(t, err)
-       ctx.SetInput(listInput)
-       retPage1, err := handler.List(ctx)
-       assert.Nil(t, err)
-       dataPage1 := retPage1.(*store.ListOutput)
-       assert.Equal(t, len(dataPage1.Rows), 1)
-
-       //list page 2
-       listInput2 := &ListInput{}
-       reqBody = `{"page_size": 1, "page": 2}`
-       err = json.Unmarshal([]byte(reqBody), listInput2)
-       assert.Nil(t, err)
-       ctx.SetInput(listInput2)
-       retPage2, err := handler.List(ctx)
-       assert.Nil(t, err)
-       dataPage2 := retPage2.(*store.ListOutput)
-       assert.Equal(t, len(dataPage2.Rows), 1)
-
-       //list search match
-       listInput3 := &ListInput{}
-       reqBody = `{"page_size": 1, "page": 1, "username": "pony"}`
-       err = json.Unmarshal([]byte(reqBody), listInput3)
-       assert.Nil(t, err)
-       ctx.SetInput(listInput3)
-       retPage, err := handler.List(ctx)
-       assert.Nil(t, err)
-       dataPage := retPage.(*store.ListOutput)
-       assert.Equal(t, len(dataPage.Rows), 1)
-
-       //list search not match
-       listInput4 := &ListInput{}
-       reqBody = `{"page_size": 1, "page": 1, "username": "not-exists"}`
-       err = json.Unmarshal([]byte(reqBody), listInput4)
-       assert.Nil(t, err)
-       ctx.SetInput(listInput4)
-       retPage, err = handler.List(ctx)
-       assert.Nil(t, err)
-       dataPage = retPage.(*store.ListOutput)
-       assert.Equal(t, len(dataPage.Rows), 0)
-
-       //delete consumer
-       inputDel := &BatchDelete{}
-       reqBody = `{"usernames": "jack"}`
-       err = json.Unmarshal([]byte(reqBody), inputDel)
-       assert.Nil(t, err)
-       ctx.SetInput(inputDel)
-       _, err = handler.BatchDelete(ctx)
-       assert.Nil(t, err)
-
-       reqBody = `{"usernames": "pony"}`
-       err = json.Unmarshal([]byte(reqBody), inputDel)
-       assert.Nil(t, err)
-       ctx.SetInput(inputDel)
-       _, err = handler.BatchDelete(ctx)
-       assert.Nil(t, err)
-
-       //create consumer fail
-       consumer_fail := &entity.Consumer{}
-       reqBody = `{
-      "plugins": {
-          "limit-count": {
-              "count": 2,
-              "time_window": 60,
-              "rejected_code": 503,
-              "key": "remote_addr"
-          }
-      },
-    "desc": "test description"
-  }`
-       err = json.Unmarshal([]byte(reqBody), consumer_fail)
-       assert.Nil(t, err)
-       ctx.SetInput(consumer_fail)
-       _, err = handler.Create(ctx)
-       assert.NotNil(t, err)
-
-       //create consumer using Update
-       consumer6 := &UpdateInput{}
-       reqBody = `{
-      "username": "nnn",
-      "plugins": {
-          "limit-count": {
-              "count": 2,
-              "time_window": 60,
-              "rejected_code": 503,
-              "key": "remote_addr"
-          }
-      },
-    "desc": "test description"
-  }`
-       err = json.Unmarshal([]byte(reqBody), consumer6)
-       assert.Nil(t, err)
-       ctx.SetInput(consumer6)
-       _, err = handler.Update(ctx)
-       assert.Nil(t, err)
-
-       //sleep
-       time.Sleep(time.Duration(100) * time.Millisecond)
-
-       //delete consumer
-       reqBody = `{"usernames": "nnn"}`
-       err = json.Unmarshal([]byte(reqBody), inputDel)
-       assert.Nil(t, err)
-       ctx.SetInput(inputDel)
-       _, err = handler.BatchDelete(ctx)
-       assert.Nil(t, err)
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.caseDesc, func(t *testing.T) {
+                       methodCalled := true
+                       mStore := &store.MockInterface{}
+                       mStore.On("Create", mock.Anything, 
mock.Anything).Run(func(args mock.Arguments) {
+                               methodCalled = true
+                               assert.Equal(t, tc.giveCtx, args.Get(0))
+                               assert.Equal(t, tc.wantInput, args.Get(1))
+                       }).Return(tc.giveErr)
+
+                       h := Handler{consumerStore: mStore}
+                       ctx := droplet.NewContext()
+                       ctx.SetInput(tc.giveInput)
+                       ctx.SetContext(tc.giveCtx)
+                       ret, err := h.Create(ctx)
+                       assert.Equal(t, tc.wantCalled, methodCalled)
+                       assert.Equal(t, tc.wantRet, ret)
+                       assert.Equal(t, tc.wantErr, err)
+               })
+       }
+}
+
+func TestHandler_Update(t *testing.T) {
+       tests := []struct {
+               caseDesc   string
+               giveInput  *UpdateInput
+               giveCtx    context.Context
+               giveErr    error
+               wantErr    error
+               wantInput  *entity.Consumer
+               wantRet    interface{}
+               wantCalled bool
+       }{
+               {
+                       caseDesc: "normal",
+                       giveInput: &UpdateInput{
+                               Username: "name",
+                               Consumer: entity.Consumer{
+                                       Plugins: map[string]interface{}{
+                                               "jwt-auth": 
map[string]interface{}{
+                                                       "exp": 500,
+                                               },
+                                       },
+                               },
+                       },
+                       giveCtx: context.WithValue(context.Background(), 
"test", "value"),
+                       wantInput: &entity.Consumer{
+                               BaseInfo: entity.BaseInfo{
+                                       ID: "name",
+                               },
+                               Username: "name",
+                               Plugins: map[string]interface{}{
+                                       "jwt-auth": map[string]interface{}{
+                                               "exp": 500,
+                                       },
+                               },
+                       },
+                       wantRet:    nil,
+                       wantCalled: true,
+               },
+               {
+                       caseDesc: "store update failed",
+                       giveInput: &UpdateInput{
+                               Username: "name",
+                               Consumer: entity.Consumer{
+                                       Plugins: map[string]interface{}{
+                                               "jwt-auth": 
map[string]interface{}{},
+                                       },
+                               },
+                       },
+                       giveErr: fmt.Errorf("create failed"),
+                       wantInput: &entity.Consumer{
+                               BaseInfo: entity.BaseInfo{
+                                       ID: "name",
+                               },
+                               Username: "name",
+                               Plugins: map[string]interface{}{
+                                       "jwt-auth": map[string]interface{}{
+                                               "exp": 86400,
+                                       },
+                               },
+                       },
+                       wantErr: fmt.Errorf("create failed"),
+                       wantRet: &data.SpecCodeResponse{
+                               StatusCode: http.StatusInternalServerError,
+                       },
+                       wantCalled: true,
+               },
+       }
 
+       for _, tc := range tests {
+               t.Run(tc.caseDesc, func(t *testing.T) {
+                       methodCalled := true
+                       mStore := &store.MockInterface{}
+                       mStore.On("Update", mock.Anything, mock.Anything, 
mock.Anything).Run(func(args mock.Arguments) {
+                               methodCalled = true
+                               assert.Equal(t, tc.giveCtx, args.Get(0))
+                               assert.Equal(t, tc.wantInput, args.Get(1))
+                               assert.True(t, args.Bool(2))
+                       }).Return(tc.giveErr)
+
+                       h := Handler{consumerStore: mStore}
+                       ctx := droplet.NewContext()
+                       ctx.SetInput(tc.giveInput)
+                       ctx.SetContext(tc.giveCtx)
+                       ret, err := h.Update(ctx)
+                       assert.Equal(t, tc.wantCalled, methodCalled)
+                       assert.Equal(t, tc.wantRet, ret)
+                       assert.Equal(t, tc.wantErr, err)
+               })
+       }
+}
+
+func TestHandler_BatchDelete(t *testing.T) {
+       tests := []struct {
+               caseDesc  string
+               giveInput *BatchDeleteInput
+               giveCtx   context.Context
+               giveErr   error
+               wantErr   error
+               wantInput []string
+               wantRet   interface{}
+       }{
+               {
+                       caseDesc: "normal",
+                       giveInput: &BatchDeleteInput{
+                               UserNames: "user1,user2",
+                       },
+                       giveCtx: context.WithValue(context.Background(), 
"test", "value"),
+                       wantInput: []string{
+                               "user1",
+                               "user2",
+                       },
+               },
+               {
+                       caseDesc: "store delete failed",
+                       giveInput: &BatchDeleteInput{
+                               UserNames: "user1,user2",
+                       },
+                       giveCtx: context.WithValue(context.Background(), 
"test", "value"),
+                       giveErr: fmt.Errorf("delete failed"),
+                       wantInput: []string{
+                               "user1",
+                               "user2",
+                       },
+                       wantErr: fmt.Errorf("delete failed"),
+                       wantRet: &data.SpecCodeResponse{
+                               StatusCode: http.StatusInternalServerError,
+                       },
+               },
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.caseDesc, func(t *testing.T) {
+                       methodCalled := true
+                       mStore := &store.MockInterface{}
+                       mStore.On("BatchDelete", mock.Anything, mock.Anything, 
mock.Anything).Run(func(args mock.Arguments) {
+                               methodCalled = true
+                               assert.Equal(t, tc.giveCtx, args.Get(0))
+                               assert.Equal(t, tc.wantInput, args.Get(1))
+                       }).Return(tc.giveErr)
+
+                       h := Handler{consumerStore: mStore}
+                       ctx := droplet.NewContext()
+                       ctx.SetInput(tc.giveInput)
+                       ctx.SetContext(tc.giveCtx)
+                       ret, err := h.BatchDelete(ctx)
+                       assert.True(t, methodCalled)
+                       assert.Equal(t, tc.wantErr, err)
+                       assert.Equal(t, tc.wantRet, ret)
+               })
+       }
 }
diff --git a/api/log/zap.go b/api/log/zap.go
index 66379d8..cd36b53 100644
--- a/api/log/zap.go
+++ b/api/log/zap.go
@@ -27,10 +27,13 @@ import (
 
 var logger *zap.SugaredLogger
 
+// TODO: it is just for integration tests, we should call "InitLogger" explicitly 
when all handlers' integration tests are removed
 func init() {
+       InitLogger()
+}
+func InitLogger() {
        logger = GetLogger(ErrorLog)
 }
-
 func GetLogger(logType Type) *zap.SugaredLogger {
        writeSyncer := fileWriter(logType)
        encoder := getEncoder(logType)

Reply via email to