This is an automated email from the ASF dual-hosted git repository.

klesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-devlake.git


The following commit(s) were added to refs/heads/main by this push:
     new 1a66e7961 feat(github): add support for refresh tokens and token 
management (#8667)
1a66e7961 is described below

commit 1a66e7961cf9ce53e2bbf89880da2c7694fa7c8e
Author: Yuvraj Singh Chauhan <[email protected]>
AuthorDate: Tue Jan 6 08:47:21 2026 +0530

    feat(github): add support for refresh tokens and token management (#8667)
    
    * feat(github): add support for refresh tokens and token management
    
    * Update backend/plugins/github/token/token_provider.go
    
    Co-authored-by: Copilot <[email protected]>
    
    * Update backend/plugins/github/token/token_provider.go
    
    Co-authored-by: Copilot <[email protected]>
    
    * Update backend/plugins/github/token/token_provider.go
    
    Co-authored-by: Copilot <[email protected]>
    
    * added documentation for roundtripper and fixed the infinite loop issue 
and added tests
    
    * fixed the illegal plugins/github/models import
    
    * Remove error conversion on token refresh failure and add token provider 
tests
    
    * renamed TestIsUserToServerToken to TestTokenTypeClassification for clarity
    
    * Update backend/plugins/github/token/token_provider.go
    
    Co-authored-by: Copilot <[email protected]>
    
    * Update backend/plugins/github/token/token_provider.go
    
    Co-authored-by: Copilot <[email protected]>
    
    ---------
    
    Co-authored-by: Copilot <[email protected]>
---
 backend/helpers/pluginhelper/api/api_client.go     |   5 +
 backend/plugins/github/models/connection.go        |  17 +-
 backend/plugins/github/models/connection_test.go   |  11 ++
 .../20241120_add_refresh_token_fields.go           |  53 ++++++
 .../github/models/migrationscripts/register.go     |   1 +
 backend/plugins/github/tasks/api_client.go         |  19 +++
 backend/plugins/github/token/round_tripper.go      |  90 ++++++++++
 backend/plugins/github/token/round_tripper_test.go | 101 +++++++++++
 backend/plugins/github/token/token_provider.go     | 185 +++++++++++++++++++++
 .../plugins/github/token/token_provider_test.go    | 180 ++++++++++++++++++++
 10 files changed, 661 insertions(+), 1 deletion(-)

diff --git a/backend/helpers/pluginhelper/api/api_client.go 
b/backend/helpers/pluginhelper/api/api_client.go
index 1e7e57d44..b0cfccf49 100644
--- a/backend/helpers/pluginhelper/api/api_client.go
+++ b/backend/helpers/pluginhelper/api/api_client.go
@@ -299,6 +299,11 @@ func (apiClient *ApiClient) SetLogger(logger log.Logger) {
        apiClient.logger = logger
 }
 
+// GetClient exposes the *http.Client this ApiClient sends its requests with.
+// The pointer is shared rather than copied, so changes made through it (for
+// example replacing its Transport) also affect this ApiClient.
+func (apiClient *ApiClient) GetClient() *http.Client {
+       client := apiClient.client
+       return client
+}
+
 func (apiClient *ApiClient) logDebug(format string, a ...interface{}) {
        if apiClient.logger != nil {
                apiClient.logger.Debug(format, a...)
diff --git a/backend/plugins/github/models/connection.go 
b/backend/plugins/github/models/connection.go
index 6a8c06a37..4dba2f756 100644
--- a/backend/plugins/github/models/connection.go
+++ b/backend/plugins/github/models/connection.go
@@ -56,6 +56,21 @@ type GithubConn struct {
        helper.MultiAuth      `mapstructure:",squash"`
        GithubAccessToken     `mapstructure:",squash" authMethod:"AccessToken"`
        GithubAppKey          `mapstructure:",squash" authMethod:"AppKey"`
+       RefreshToken          string    `mapstructure:"refreshToken" 
json:"refreshToken" gorm:"type:text;serializer:encdec"`
+       TokenExpiresAt        time.Time `mapstructure:"tokenExpiresAt" 
json:"tokenExpiresAt"`
+       RefreshTokenExpiresAt time.Time `mapstructure:"refreshTokenExpiresAt" 
json:"refreshTokenExpiresAt"`
+}
+
+// UpdateToken installs a new access/refresh token pair and their expiry times.
+// It also resets the internal token list consumed by SetupAuthentication so
+// that only the new access token is used from now on.
+func (conn *GithubConn) UpdateToken(newToken, newRefreshToken string, expiry, refreshExpiry time.Time) {
+       conn.tokens = []string{newToken}
+       conn.tokenIndex = 0
+
+       conn.Token, conn.RefreshToken = newToken, newRefreshToken
+       conn.TokenExpiresAt, conn.RefreshTokenExpiresAt = expiry, refreshExpiry
 }
 
 // PrepareApiClient splits Token to tokens for SetupAuthentication to utilize
@@ -249,7 +264,7 @@ func (conn *GithubConn) typeIs(token string) string {
        // total len is 40, {prefix}{showPrefix}{secret}{showSuffix}
        // fine-grained tokens
        // github_pat_{82_characters}
-       classicalTokenClassicalPrefixes := []string{"ghp_", "gho_", "ghs_", 
"ghr_"}
+       classicalTokenClassicalPrefixes := []string{"ghp_", "gho_", "ghs_", 
"ghr_", "ghu_"}
        classicalTokenFindGrainedPrefixes := []string{"github_pat_"}
        for _, prefix := range classicalTokenClassicalPrefixes {
                if strings.HasPrefix(token, prefix) {
diff --git a/backend/plugins/github/models/connection_test.go 
b/backend/plugins/github/models/connection_test.go
index 7b39cf0bd..41323b9b5 100644
--- a/backend/plugins/github/models/connection_test.go
+++ b/backend/plugins/github/models/connection_test.go
@@ -227,3 +227,14 @@ func TestGithubConnection_Sanitize(t *testing.T) {
                })
        }
 }
+
+// TestTokenTypeClassification checks that typeIs maps each known token prefix
+// (including the user-to-server "ghu_" prefix) to its token type.
+func TestTokenTypeClassification(t *testing.T) {
+       conn := &GithubConn{}
+       cases := []struct {
+               token string
+               want  string
+       }{
+               {"ghp_123", GithubTokenTypeClassical},
+               {"gho_123", GithubTokenTypeClassical},
+               {"ghu_123", GithubTokenTypeClassical},
+               {"ghs_123", GithubTokenTypeClassical},
+               {"ghr_123", GithubTokenTypeClassical},
+               {"github_pat_123", GithubTokenTypeFineGrained},
+               {"some_other_token", GithubTokenTypeUnknown},
+       }
+       for _, c := range cases {
+               assert.Equal(t, c.want, conn.typeIs(c.token), "token %q", c.token)
+       }
+}
diff --git 
a/backend/plugins/github/models/migrationscripts/20241120_add_refresh_token_fields.go
 
b/backend/plugins/github/models/migrationscripts/20241120_add_refresh_token_fields.go
new file mode 100644
index 000000000..b2f826da3
--- /dev/null
+++ 
b/backend/plugins/github/models/migrationscripts/20241120_add_refresh_token_fields.go
@@ -0,0 +1,53 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package migrationscripts
+
+import (
+       "time"
+
+       "github.com/apache/incubator-devlake/core/context"
+       "github.com/apache/incubator-devlake/core/errors"
+       "github.com/apache/incubator-devlake/helpers/migrationhelper"
+)
+
+// githubConnection20241120 is a frozen snapshot of the columns this migration
+// introduces on _tool_github_connections; Up passes it to
+// migrationhelper.AutoMigrateTables.
+type githubConnection20241120 struct {
+       RefreshToken          string `gorm:"type:text;serializer:encdec"`
+       TokenExpiresAt        time.Time
+       RefreshTokenExpiresAt time.Time
+}
+
+// TableName binds the snapshot to the GitHub connection table.
+func (githubConnection20241120) TableName() string {
+       return "_tool_github_connections"
+}
+
+// addRefreshTokenFields adds refresh-token storage and token expiry timestamps
+// to GitHub connections.
+type addRefreshTokenFields struct{}
+
+// Version returns this script's version number.
+// NOTE(review): the version is dated 2024-11-20 while the commit is dated
+// 2026-01-06; confirm it still orders correctly relative to github migrations
+// released in between, or renumber it.
+func (*addRefreshTokenFields) Version() uint64 {
+       return 20241120000001
+}
+
+// Name returns a human-readable description of this script.
+func (*addRefreshTokenFields) Name() string {
+       return "add refresh token fields to github_connections"
+}
+
+// Up applies the schema change described by githubConnection20241120.
+func (*addRefreshTokenFields) Up(basicRes context.BasicRes) errors.Error {
+       target := &githubConnection20241120{}
+       return migrationhelper.AutoMigrateTables(basicRes, target)
+}
diff --git a/backend/plugins/github/models/migrationscripts/register.go 
b/backend/plugins/github/models/migrationscripts/register.go
index b8a0722eb..74f9d712b 100644
--- a/backend/plugins/github/models/migrationscripts/register.go
+++ b/backend/plugins/github/models/migrationscripts/register.go
@@ -55,5 +55,6 @@ func All() []plugin.MigrationScript {
                new(addIsDraftToPr),
                new(changeIssueComponentType),
                new(addIndexToGithubJobs),
+               new(addRefreshTokenFields),
        }
 }
diff --git a/backend/plugins/github/tasks/api_client.go 
b/backend/plugins/github/tasks/api_client.go
index c9bfa852c..268af8ece 100644
--- a/backend/plugins/github/tasks/api_client.go
+++ b/backend/plugins/github/tasks/api_client.go
@@ -26,6 +26,7 @@ import (
        "github.com/apache/incubator-devlake/core/plugin"
        "github.com/apache/incubator-devlake/helpers/pluginhelper/api"
        "github.com/apache/incubator-devlake/plugins/github/models"
+       "github.com/apache/incubator-devlake/plugins/github/token"
 )
 
 func CreateApiClient(taskCtx plugin.TaskContext, connection 
*models.GithubConnection) (*api.ApiAsyncClient, errors.Error) {
@@ -34,6 +35,24 @@ func CreateApiClient(taskCtx plugin.TaskContext, connection 
*models.GithubConnec
                return nil, err
        }
 
+       // Inject TokenProvider if refresh token is present
+       if connection.RefreshToken != "" {
+               logger := taskCtx.GetLogger()
+               db := taskCtx.GetDal()
+
+               // Create TokenProvider
+               tp := token.NewTokenProvider(connection, db, 
apiClient.GetClient(), logger)
+
+               // Wrap the transport
+               baseTransport := apiClient.GetClient().Transport
+               if baseTransport == nil {
+                       baseTransport = http.DefaultTransport
+               }
+
+               rt := token.NewRefreshRoundTripper(baseTransport, tp)
+               apiClient.GetClient().Transport = rt
+       }
+
        // create rate limit calculator
        rateLimiter := &api.ApiRateLimitCalculator{
                UserRateLimitPerHour: connection.RateLimitPerHour,
diff --git a/backend/plugins/github/token/round_tripper.go 
b/backend/plugins/github/token/round_tripper.go
new file mode 100644
index 000000000..45ba3e9a7
--- /dev/null
+++ b/backend/plugins/github/token/round_tripper.go
@@ -0,0 +1,90 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package token
+
+import (
+       "io"
+       "net/http"
+)
+
+// RefreshRoundTripper is an http.RoundTripper middleware that keeps the OAuth
+// access token fresh for every request passing through it.
+//
+// Each outgoing request is cloned and sent with an "Authorization: Bearer"
+// header carrying the token obtained from the TokenProvider. When the wrapped
+// transport answers 401 Unauthorized, the token is force-refreshed and the
+// request is retried, at most once.
+type RefreshRoundTripper struct {
+       tokenProvider *TokenProvider    // source of access tokens
+       base          http.RoundTripper // transport that actually sends requests
+}
+
+// NewRefreshRoundTripper wraps base so that requests sent through it carry a
+// Bearer token from tp and are retried at most once after a forced refresh
+// when the server answers 401.
+func NewRefreshRoundTripper(base http.RoundTripper, tp *TokenProvider) *RefreshRoundTripper {
+       rt := new(RefreshRoundTripper)
+       rt.base = base
+       rt.tokenProvider = tp
+       return rt
+}
+
+// RoundTrip implements the http.RoundTripper interface and handles automatic 
token refresh on 401 responses.
+// It clones the request, adds the Authorization header, and retries once with 
a refreshed token if needed.
+func (rt *RefreshRoundTripper) RoundTrip(req *http.Request) (*http.Response, 
error) {
+       return rt.roundTripWithRetry(req, false)
+}
+
+// roundTripWithRetry sends req with the current access token. When the response
+// is 401 Unauthorized, it forces a token refresh and resends the request once.
+// refreshAttempted is true on that second pass, so a token that is still
+// rejected after refreshing cannot cause an endless loop; the second 401 is
+// returned to the caller as-is.
+//
+// A request whose body cannot be rewound (non-nil Body without GetBody) is not
+// retried, because the first attempt has already consumed that body; its 401
+// response is returned unchanged. As the http.RoundTripper contract requires,
+// the request body is closed on paths where it never reaches the base transport.
+func (rt *RefreshRoundTripper) roundTripWithRetry(req *http.Request, refreshAttempted bool) (*http.Response, error) {
+       token, tokenErr := rt.tokenProvider.GetToken()
+       if tokenErr != nil {
+               closeRequestBody(req)
+               return nil, tokenErr
+       }
+
+       // Never mutate the caller's request: set the header on a clone.
+       reqClone := req.Clone(req.Context())
+       reqClone.Header.Set("Authorization", "Bearer "+token)
+
+       resp, err := rt.base.RoundTrip(reqClone)
+       if err != nil {
+               return nil, err
+       }
+
+       if resp.StatusCode != http.StatusUnauthorized || refreshAttempted || !canReplay(req) {
+               return resp, nil
+       }
+
+       // Drain and close the rejected response so its connection can be reused.
+       _, _ = io.Copy(io.Discard, resp.Body)
+       resp.Body.Close()
+
+       // ForceRefresh does nothing if another request has already replaced this token.
+       if refreshErr := rt.tokenProvider.ForceRefresh(token); refreshErr != nil {
+               return nil, refreshErr
+       }
+
+       retryReq, err := rewind(req)
+       if err != nil {
+               return nil, err
+       }
+       // The second pass fetches the refreshed token itself via GetToken.
+       return rt.roundTripWithRetry(retryReq, true)
+}
+
+// canReplay reports whether req can be sent a second time: it has no body, or
+// its body can be recreated through GetBody.
+func canReplay(req *http.Request) bool {
+       return req.Body == nil || req.Body == http.NoBody || req.GetBody != nil
+}
+
+// rewind returns a request equivalent to req with a fresh copy of its body, or
+// req itself when there is no body to recreate.
+func rewind(req *http.Request) (*http.Request, error) {
+       if req.Body == nil || req.Body == http.NoBody {
+               return req, nil
+       }
+       body, err := req.GetBody()
+       if err != nil {
+               return nil, err
+       }
+       fresh := req.Clone(req.Context())
+       fresh.Body = body
+       return fresh, nil
+}
+
+// closeRequestBody closes req's body, if any. It is used on paths where the
+// request is never handed to the base transport, which would otherwise close it.
+func closeRequestBody(req *http.Request) {
+       if req.Body != nil {
+               _ = req.Body.Close()
+       }
+}
diff --git a/backend/plugins/github/token/round_tripper_test.go 
b/backend/plugins/github/token/round_tripper_test.go
new file mode 100644
index 000000000..6767d8ccf
--- /dev/null
+++ b/backend/plugins/github/token/round_tripper_test.go
@@ -0,0 +1,101 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package token
+
+import (
+       "bytes"
+       "io"
+       "net/http"
+       "testing"
+       "time"
+
+       "github.com/apache/incubator-devlake/helpers/pluginhelper/api"
+       "github.com/apache/incubator-devlake/impls/logruslog"
+       "github.com/apache/incubator-devlake/plugins/github/models"
+       "github.com/sirupsen/logrus"
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/mock"
+)
+
+// TestRoundTripper401Refresh verifies that a 401 triggers exactly one token
+// refresh and one retry of the original request with the new token.
+func TestRoundTripper401Refresh(t *testing.T) {
+       mockRT := new(MockRoundTripper)
+       client := &http.Client{Transport: mockRT}
+
+       conn := &models.GithubConnection{
+               GithubConn: models.GithubConn{
+                       RefreshToken: "refresh_token",
+                       GithubAccessToken: models.GithubAccessToken{
+                               AccessToken: api.AccessToken{
+                                       Token: "old_token",
+                               },
+                       },
+                       TokenExpiresAt: time.Now().Add(10 * time.Minute), // Not expired
+                       GithubAppKey: models.GithubAppKey{
+                               AppKey: api.AppKey{
+                                       AppId:     "123",
+                                       SecretKey: "secret",
+                               },
+                       },
+               },
+       }
+
+       logger, _ := logruslog.NewDefaultLogger(logrus.New())
+       tp := NewTokenProvider(conn, nil, client, logger)
+       rt := NewRefreshRoundTripper(mockRT, tp)
+
+       // Request
+       req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       // 1. First call returns 401
+       resp401 := &http.Response{
+               StatusCode: 401,
+               Body:       io.NopCloser(bytes.NewBufferString("Unauthorized")),
+       }
+       mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool {
+               return r.Header.Get("Authorization") == "Bearer old_token" && r.URL.String() == "https://api.github.com/user"
+       })).Return(resp401, nil).Once()
+
+       // 2. Refresh call (triggered by 401)
+       respRefresh := &http.Response{
+               StatusCode: 200,
+               Body:       io.NopCloser(bytes.NewBufferString(`{"access_token":"new_token","refresh_token":"new_refresh_token","expires_in":3600,"refresh_token_expires_in":3600}`)),
+       }
+       // The refresh call uses the same client, so it goes through mockRT too!
+       mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool {
+               return r.URL.String() == "https://github.com/login/oauth/access_token"
+       })).Return(respRefresh, nil).Once()
+
+       // 3. Retry call with new token
+       resp200 := &http.Response{
+               StatusCode: 200,
+               Body:       io.NopCloser(bytes.NewBufferString("Success")),
+       }
+       mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool {
+               return r.Header.Get("Authorization") == "Bearer new_token" && r.URL.String() == "https://api.github.com/user"
+       })).Return(resp200, nil).Once()
+
+       // Execute
+       resp, err := rt.RoundTrip(req)
+       if !assert.NoError(t, err) {
+               return
+       }
+       defer resp.Body.Close()
+       assert.Equal(t, 200, resp.StatusCode)
+
+       body, err := io.ReadAll(resp.Body)
+       assert.NoError(t, err)
+       assert.Equal(t, "Success", string(body))
+
+       mockRT.AssertExpectations(t)
+}
diff --git a/backend/plugins/github/token/token_provider.go 
b/backend/plugins/github/token/token_provider.go
new file mode 100644
index 000000000..ba9941cd4
--- /dev/null
+++ b/backend/plugins/github/token/token_provider.go
@@ -0,0 +1,185 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package token
+
+import (
+       "bytes"
+       "context"
+       "encoding/json"
+       "fmt"
+       "io"
+       "net/http"
+       "os"
+       "strconv"
+       "sync"
+       "time"
+
+       "github.com/apache/incubator-devlake/core/dal"
+       "github.com/apache/incubator-devlake/core/errors"
+       "github.com/apache/incubator-devlake/core/log"
+       "github.com/apache/incubator-devlake/plugins/github/models"
+)
+
+// DefaultRefreshBuffer is how long before TokenExpiresAt a token is treated as
+// due for refresh when GITHUB_TOKEN_REFRESH_BUFFER_MINUTES is unset or not an
+// integer.
+const DefaultRefreshBuffer = 5 * time.Minute
+
+// TokenProvider hands out the access token of a GitHub connection and refreshes
+// it with the connection's refresh token. GetToken and ForceRefresh hold mu
+// while they read or refresh the connection's token.
+type TokenProvider struct {
+       mu sync.Mutex // serializes token reads and refreshes
+
+       conn       *models.GithubConnection // tokens are read from and updated on this connection
+       dal        dal.Dal                  // optional; when non-nil, refreshed tokens are persisted
+       httpClient *http.Client             // client used to call the refresh endpoint
+       logger     log.Logger
+       refreshURL string // OAuth token endpoint
+}
+
+// NewTokenProvider creates a TokenProvider for conn.
+//
+// d may be nil, in which case refreshed tokens are only kept in memory on conn.
+// client is used to call GitHub's OAuth token endpoint; logger receives refresh
+// diagnostics.
+func NewTokenProvider(conn *models.GithubConnection, d dal.Dal, client *http.Client, logger log.Logger) *TokenProvider {
+       return &TokenProvider{
+               conn:       conn,
+               dal:        d,
+               httpClient: client,
+               logger:     logger,
+               refreshURL: "https://github.com/login/oauth/access_token",
+       }
+}
+
+// GetToken returns the connection's current access token. If needsRefresh
+// reports that it has expired or falls inside the refresh buffer, it is
+// refreshed first. The provider's lock is held for the whole call, so
+// concurrent callers are serialized and normally observe the token refreshed
+// by whichever of them ran first.
+func (tp *TokenProvider) GetToken() (string, errors.Error) {
+       tp.mu.Lock()
+       defer tp.mu.Unlock()
+
+       if !tp.needsRefresh() {
+               return tp.conn.Token, nil
+       }
+       if err := tp.refreshToken(); err != nil {
+               return "", err
+       }
+       return tp.conn.Token, nil
+}
+
+// needsRefresh reports whether the access token should be refreshed before use.
+// Only connections with a refresh token qualify, and they do once
+// TokenExpiresAt falls within the refresh buffer from now. A zero
+// TokenExpiresAt therefore always qualifies.
+func (tp *TokenProvider) needsRefresh() bool {
+       if tp.conn.RefreshToken == "" {
+               return false
+       }
+       buffer := refreshBuffer()
+       return time.Now().Add(buffer).After(tp.conn.TokenExpiresAt)
+}
+
+// refreshBuffer returns the refresh lead time: the whole number of minutes in
+// GITHUB_TOKEN_REFRESH_BUFFER_MINUTES, or DefaultRefreshBuffer when that
+// variable is unset or not an integer.
+func refreshBuffer() time.Duration {
+       raw := os.Getenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES")
+       if raw == "" {
+               return DefaultRefreshBuffer
+       }
+       minutes, err := strconv.Atoi(raw)
+       if err != nil {
+               return DefaultRefreshBuffer
+       }
+       return time.Duration(minutes) * time.Minute
+}
+
+// refreshResponse is the JSON payload returned by GitHub's OAuth token
+// endpoint. On failure the endpoint may report Error/ErrorDescription in the
+// body instead of an access token.
+type refreshResponse struct {
+       AccessToken           string `json:"access_token"`
+       RefreshToken          string `json:"refresh_token"`
+       ExpiresIn             int    `json:"expires_in"`
+       RefreshTokenExpiresIn int    `json:"refresh_token_expires_in"`
+       Error                 string `json:"error"`
+       ErrorDescription      string `json:"error_description"`
+}
+
+// maxBodySnippet caps how much of an unexpected response body is echoed into
+// logs and error messages.
+const maxBodySnippet = 512
+
+// refreshToken exchanges the connection's refresh token for a new access token.
+// It updates tp.conn in place and, when a DAL is configured, persists the new
+// values. It must be called with tp.mu held, as GetToken and ForceRefresh do.
+//
+// If GitHub's response carries no new refresh token, the current refresh token
+// and its expiry are kept rather than overwritten with empty values.
+func (tp *TokenProvider) refreshToken() errors.Error {
+       tp.logInfo("Refreshing GitHub token for connection %d", tp.conn.ID)
+
+       result, err := tp.requestRefresh()
+       if err != nil {
+               return err
+       }
+
+       now := time.Now()
+       newRefresh, newRefreshExpiry := tp.conn.RefreshToken, tp.conn.RefreshTokenExpiresAt
+       if result.RefreshToken != "" {
+               newRefresh = result.RefreshToken
+               newRefreshExpiry = now.Add(time.Duration(result.RefreshTokenExpiresIn) * time.Second)
+       }
+       tp.conn.UpdateToken(
+               result.AccessToken,
+               newRefresh,
+               now.Add(time.Duration(result.ExpiresIn)*time.Second),
+               newRefreshExpiry,
+       )
+
+       tp.persistToken()
+       return nil
+}
+
+// requestRefresh calls the token endpoint and returns the decoded response. It
+// fails when the call fails, GitHub answers with a non-200 status, or no access
+// token is returned.
+func (tp *TokenProvider) requestRefresh() (*refreshResponse, errors.Error) {
+       // Marshalling a map[string]string cannot fail.
+       payload, _ := json.Marshal(map[string]string{
+               "refresh_token": tp.conn.RefreshToken,
+               "grant_type":    "refresh_token",
+               "client_id":     tp.conn.AppId,
+               "client_secret": tp.conn.SecretKey,
+       })
+
+       ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+       defer cancel()
+
+       req, err := http.NewRequestWithContext(ctx, "POST", tp.refreshURL, bytes.NewBuffer(payload))
+       if err != nil {
+               return nil, errors.Convert(err)
+       }
+       req.Header.Set("Content-Type", "application/json")
+       req.Header.Set("Accept", "application/json")
+
+       resp, err := tp.refreshClient().Do(req)
+       if err != nil {
+               return nil, errors.Convert(err)
+       }
+       defer resp.Body.Close()
+
+       body, err := io.ReadAll(resp.Body)
+       if err != nil {
+               return nil, errors.Convert(err)
+       }
+
+       if resp.StatusCode != http.StatusOK {
+               // Log the (bounded) response body to aid in debugging refresh failures.
+               tp.logError("failed to refresh token from GitHub, status=%d, body=%s", resp.StatusCode, snippet(body))
+               return nil, errors.Default.New(fmt.Sprintf("failed to refresh token: %d, body: %s", resp.StatusCode, snippet(body)))
+       }
+
+       var result refreshResponse
+       if err := json.Unmarshal(body, &result); err != nil {
+               return nil, errors.Convert(err)
+       }
+       if result.AccessToken == "" {
+               if result.Error != "" {
+                       return nil, errors.Default.New(fmt.Sprintf("failed to refresh token: %s: %s", result.Error, result.ErrorDescription))
+               }
+               return nil, errors.Default.New(fmt.Sprintf("empty access token returned; response body: %s", snippet(body)))
+       }
+       return &result, nil
+}
+
+// refreshClient returns the client to use for token-endpoint calls.
+//
+// CreateApiClient passes a client to NewTokenProvider and then replaces that
+// same client's Transport with a RefreshRoundTripper. Sending the refresh
+// through that wrapper would call GetToken, which blocks on tp.mu already held
+// by this goroutine, and the task would deadlock. In that case a shallow copy
+// of the client using the wrapper's base transport is returned instead.
+func (tp *TokenProvider) refreshClient() *http.Client {
+       if tp.httpClient == nil {
+               return http.DefaultClient
+       }
+       if wrapper, ok := tp.httpClient.Transport.(*RefreshRoundTripper); ok {
+               direct := *tp.httpClient
+               direct.Transport = wrapper.base
+               return &direct
+       }
+       return tp.httpClient
+}
+
+// persistToken writes the refreshed token fields back to the connection row
+// when a DAL is configured. A failed write is only logged, so the in-memory
+// token can still be used for the rest of the task.
+// NOTE(review): if GitHub invalidates the old refresh token on use (as its docs
+// describe for expiring user tokens), a failed write leaves a stale refresh
+// token stored for later runs. Also confirm that UpdateColumns applies the
+// encdec serializer to token/refresh_token rather than writing plaintext.
+func (tp *TokenProvider) persistToken() {
+       if tp.dal == nil {
+               return
+       }
+       err := tp.dal.UpdateColumns(tp.conn, []dal.DalSet{
+               {ColumnName: "token", Value: tp.conn.Token},
+               {ColumnName: "refresh_token", Value: tp.conn.RefreshToken},
+               {ColumnName: "token_expires_at", Value: tp.conn.TokenExpiresAt},
+               {ColumnName: "refresh_token_expires_at", Value: tp.conn.RefreshTokenExpiresAt},
+       })
+       if err != nil && tp.logger != nil {
+               tp.logger.Warn(err, "failed to persist refreshed token")
+       }
+}
+
+// logInfo logs at info level when a logger is configured.
+func (tp *TokenProvider) logInfo(format string, a ...interface{}) {
+       if tp.logger != nil {
+               tp.logger.Info(format, a...)
+       }
+}
+
+// logError logs at error level when a logger is configured.
+func (tp *TokenProvider) logError(format string, a ...interface{}) {
+       if tp.logger != nil {
+               tp.logger.Error(nil, format, a...)
+       }
+}
+
+// snippet returns body as a string, truncated to maxBodySnippet bytes.
+func snippet(body []byte) string {
+       if len(body) <= maxBodySnippet {
+               return string(body)
+       }
+       return string(body[:maxBodySnippet]) + "…"
+}
+
+// ForceRefresh refreshes the access token regardless of its expiry time,
+// provided the provider still holds oldToken. oldToken is the value the caller
+// sent when it saw the token rejected. If the provider now holds a different
+// token, another caller has already refreshed it, and nil is returned without
+// contacting GitHub again.
+func (tp *TokenProvider) ForceRefresh(oldToken string) errors.Error {
+       tp.mu.Lock()
+       defer tp.mu.Unlock()
+
+       alreadyRefreshed := tp.conn.Token != oldToken
+       if alreadyRefreshed {
+               return nil
+       }
+       return tp.refreshToken()
+}
diff --git a/backend/plugins/github/token/token_provider_test.go 
b/backend/plugins/github/token/token_provider_test.go
new file mode 100644
index 000000000..1c296376a
--- /dev/null
+++ b/backend/plugins/github/token/token_provider_test.go
@@ -0,0 +1,180 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package token
+
+import (
+       "bytes"
+       "io"
+       "net/http"
+       "os"
+       "sync"
+       "testing"
+       "time"
+
+       "github.com/apache/incubator-devlake/core/errors"
+       "github.com/apache/incubator-devlake/helpers/pluginhelper/api"
+       "github.com/apache/incubator-devlake/impls/logruslog"
+       mockdal "github.com/apache/incubator-devlake/mocks/core/dal"
+       "github.com/apache/incubator-devlake/plugins/github/models"
+       "github.com/sirupsen/logrus"
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/mock"
+)
+
+// MockRoundTripper is a testify mock implementing http.RoundTripper.
+type MockRoundTripper struct {
+       mock.Mock
+}
+
+// RoundTrip records the call and returns the configured response and error.
+// A nil response (e.g. Return(nil, someErr)) is passed through as nil instead
+// of panicking on the type assertion.
+func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+       args := m.Called(req)
+       resp, _ := args.Get(0).(*http.Response)
+       return resp, args.Error(1)
+}
+
+// TestNeedsRefresh checks the expiry decision against the default 5-minute
+// refresh buffer.
+func TestNeedsRefresh(t *testing.T) {
+       // needsRefresh honours GITHUB_TOKEN_REFRESH_BUFFER_MINUTES. Clear any value
+       // inherited from the environment, since these cases assume the default buffer.
+       if prev, ok := os.LookupEnv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES"); ok {
+               _ = os.Unsetenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES")
+               defer os.Setenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES", prev)
+       }
+
+       tp := &TokenProvider{
+               conn: &models.GithubConnection{
+                       GithubConn: models.GithubConn{
+                               RefreshToken: "refresh_token",
+                       },
+               },
+       }
+
+       // Not expired, outside buffer
+       tp.conn.TokenExpiresAt = time.Now().Add(10 * time.Minute)
+       assert.False(t, tp.needsRefresh())
+
+       // Inside buffer
+       tp.conn.TokenExpiresAt = time.Now().Add(1 * time.Minute)
+       assert.True(t, tp.needsRefresh())
+
+       // Expired
+       tp.conn.TokenExpiresAt = time.Now().Add(-1 * time.Minute)
+       assert.True(t, tp.needsRefresh())
+
+       // No refresh token
+       tp.conn.RefreshToken = ""
+       assert.False(t, tp.needsRefresh())
+}
+
+// TestTokenProviderConcurrency checks that concurrent GetToken calls on an
+// expired token trigger a single refresh and all receive the new token.
+func TestTokenProviderConcurrency(t *testing.T) {
+       transport := new(MockRoundTripper)
+
+       conn := &models.GithubConnection{
+               GithubConn: models.GithubConn{
+                       RefreshToken:   "refresh_token",
+                       TokenExpiresAt: time.Now().Add(-1 * time.Minute), // Expired
+                       GithubAppKey: models.GithubAppKey{
+                               AppKey: api.AppKey{
+                                       AppId:     "123",
+                                       SecretKey: "secret",
+                               },
+                       },
+               },
+       }
+
+       logger, _ := logruslog.NewDefaultLogger(logrus.New())
+       tp := NewTokenProvider(conn, nil, &http.Client{Transport: transport}, logger)
+
+       // Exactly one refresh round trip is allowed.
+       refreshed := &http.Response{
+               StatusCode: 200,
+               Body:       io.NopCloser(bytes.NewBufferString(`{"access_token":"new_token","refresh_token":"new_refresh_token","expires_in":3600,"refresh_token_expires_in":3600}`)),
+       }
+       transport.On("RoundTrip", mock.Anything).Return(refreshed, nil).Once()
+
+       const callers = 10
+       tokens := make([]string, callers)
+       errs := make([]errors.Error, callers)
+
+       var wg sync.WaitGroup
+       wg.Add(callers)
+       for i := 0; i < callers; i++ {
+               go func(idx int) {
+                       defer wg.Done()
+                       tokens[idx], errs[idx] = tp.GetToken()
+               }(i)
+       }
+       wg.Wait()
+
+       for i := 0; i < callers; i++ {
+               assert.NoError(t, errs[i])
+               assert.Equal(t, "new_token", tokens[i])
+       }
+       transport.AssertExpectations(t)
+}
+
+func TestConfigurableBuffer(t *testing.T) {
+       os.Setenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES", "10")
+       defer os.Unsetenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES")
+
+       tp := &TokenProvider{
+               conn: &models.GithubConnection{
+                       GithubConn: models.GithubConn{
+                               RefreshToken: "refresh_token",
+                       },
+               },
+       }
+
+       // 9 minutes remaining (inside 10m buffer)
+       tp.conn.TokenExpiresAt = time.Now().Add(9 * time.Minute)
+       assert.True(t, tp.needsRefresh())
+
+       // 11 minutes remaining (outside 10m buffer)
+       tp.conn.TokenExpiresAt = time.Now().Add(11 * time.Minute)
+       assert.False(t, tp.needsRefresh())
+}
+
+// TestPersistenceFailure checks that a refresh still succeeds when writing the
+// new token back to the database fails.
+func TestPersistenceFailure(t *testing.T) {
+       transport := new(MockRoundTripper)
+       store := new(mockdal.Dal)
+
+       conn := &models.GithubConnection{
+               GithubConn: models.GithubConn{
+                       RefreshToken: "refresh_token",
+                       GithubAccessToken: models.GithubAccessToken{
+                               AccessToken: api.AccessToken{
+                                       Token: "old_token",
+                               },
+                       },
+                       GithubAppKey: models.GithubAppKey{
+                               AppKey: api.AppKey{
+                                       AppId:     "123",
+                                       SecretKey: "secret",
+                               },
+                       },
+               },
+       }
+
+       logger, _ := logruslog.NewDefaultLogger(logrus.New())
+       tp := NewTokenProvider(conn, store, &http.Client{Transport: transport}, logger)
+
+       // GitHub hands back a fresh token pair...
+       transport.On("RoundTrip", mock.Anything).Return(&http.Response{
+               StatusCode: 200,
+               Body:       io.NopCloser(bytes.NewBufferString(`{"access_token":"new_token","refresh_token":"new_refresh_token","expires_in":3600,"refresh_token_expires_in":3600}`)),
+       }, nil).Once()
+       // ...but writing it back to the database fails.
+       store.On("UpdateColumns", mock.Anything, mock.Anything, mock.AnythingOfType("[]dal.Clause")).Return(errors.Default.New("db error"))
+
+       // A persistence failure must not surface as a refresh error.
+       assert.NoError(t, tp.ForceRefresh("old_token"))
+
+       transport.AssertExpectations(t)
+       store.AssertExpectations(t)
+}

Reply via email to