This is an automated email from the ASF dual-hosted git repository.
klesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-devlake.git
The following commit(s) were added to refs/heads/main by this push:
new 1a66e7961 feat(github): add support for refresh tokens and token
management (#8667)
1a66e7961 is described below
commit 1a66e7961cf9ce53e2bbf89880da2c7694fa7c8e
Author: Yuvraj Singh Chauhan <[email protected]>
AuthorDate: Tue Jan 6 08:47:21 2026 +0530
feat(github): add support for refresh tokens and token management (#8667)
* feat(github): add support for refresh tokens and token management
* Update backend/plugins/github/token/token_provider.go
Co-authored-by: Copilot <[email protected]>
* Update backend/plugins/github/token/token_provider.go
Co-authored-by: Copilot <[email protected]>
* Update backend/plugins/github/token/token_provider.go
Co-authored-by: Copilot <[email protected]>
* Added documentation for the round tripper, fixed the infinite-loop issue,
and added tests
* fixed the illegal plugins/github/models import
* Remove error conversion on token refresh failure and add token provider
tests
* renamed TestIsUserToServerToken to TestTokenTypeClassification for clarity
* Update backend/plugins/github/token/token_provider.go
Co-authored-by: Copilot <[email protected]>
* Update backend/plugins/github/token/token_provider.go
Co-authored-by: Copilot <[email protected]>
---------
Co-authored-by: Copilot <[email protected]>
---
backend/helpers/pluginhelper/api/api_client.go | 5 +
backend/plugins/github/models/connection.go | 17 +-
backend/plugins/github/models/connection_test.go | 11 ++
.../20241120_add_refresh_token_fields.go | 53 ++++++
.../github/models/migrationscripts/register.go | 1 +
backend/plugins/github/tasks/api_client.go | 19 +++
backend/plugins/github/token/round_tripper.go | 90 ++++++++++
backend/plugins/github/token/round_tripper_test.go | 101 +++++++++++
backend/plugins/github/token/token_provider.go | 185 +++++++++++++++++++++
.../plugins/github/token/token_provider_test.go | 180 ++++++++++++++++++++
10 files changed, 661 insertions(+), 1 deletion(-)
diff --git a/backend/helpers/pluginhelper/api/api_client.go
b/backend/helpers/pluginhelper/api/api_client.go
index 1e7e57d44..b0cfccf49 100644
--- a/backend/helpers/pluginhelper/api/api_client.go
+++ b/backend/helpers/pluginhelper/api/api_client.go
@@ -299,6 +299,11 @@ func (apiClient *ApiClient) SetLogger(logger log.Logger) {
apiClient.logger = logger
}
// GetClient hands back the *http.Client this ApiClient holds. The pointer is
// returned as-is rather than copied, so a caller that mutates it (for example
// by replacing its Transport) changes the client held by this ApiClient.
func (apiClient *ApiClient) GetClient() *http.Client {
	underlying := apiClient.client
	return underlying
}
+
func (apiClient *ApiClient) logDebug(format string, a ...interface{}) {
if apiClient.logger != nil {
apiClient.logger.Debug(format, a...)
diff --git a/backend/plugins/github/models/connection.go
b/backend/plugins/github/models/connection.go
index 6a8c06a37..4dba2f756 100644
--- a/backend/plugins/github/models/connection.go
+++ b/backend/plugins/github/models/connection.go
@@ -56,6 +56,21 @@ type GithubConn struct {
helper.MultiAuth `mapstructure:",squash"`
GithubAccessToken `mapstructure:",squash" authMethod:"AccessToken"`
GithubAppKey `mapstructure:",squash" authMethod:"AppKey"`
+ RefreshToken string `mapstructure:"refreshToken"
json:"refreshToken" gorm:"type:text;serializer:encdec"`
+ TokenExpiresAt time.Time `mapstructure:"tokenExpiresAt"
json:"tokenExpiresAt"`
+ RefreshTokenExpiresAt time.Time `mapstructure:"refreshTokenExpiresAt"
json:"refreshTokenExpiresAt"`
+}
+
// UpdateToken installs a newly issued access token and refresh token together
// with their expiry instants on the connection. It also resets the internal
// token rotation state (tokens / tokenIndex) so that only the new access
// token is used from now on.
func (conn *GithubConn) UpdateToken(newToken, newRefreshToken string, expiry, refreshExpiry time.Time) {
	// Credential fields of the connection.
	conn.Token, conn.RefreshToken = newToken, newRefreshToken
	conn.TokenExpiresAt, conn.RefreshTokenExpiresAt = expiry, refreshExpiry

	// Rotation state consumed by SetupAuthentication: a single token, starting
	// at index 0.
	conn.tokenIndex = 0
	conn.tokens = []string{newToken}
}
// PrepareApiClient splits Token to tokens for SetupAuthentication to utilize
@@ -249,7 +264,7 @@ func (conn *GithubConn) typeIs(token string) string {
// total len is 40, {prefix}{showPrefix}{secret}{showSuffix}
// fine-grained tokens
// github_pat_{82_characters}
- classicalTokenClassicalPrefixes := []string{"ghp_", "gho_", "ghs_",
"ghr_"}
+ classicalTokenClassicalPrefixes := []string{"ghp_", "gho_", "ghs_",
"ghr_", "ghu_"}
classicalTokenFindGrainedPrefixes := []string{"github_pat_"}
for _, prefix := range classicalTokenClassicalPrefixes {
if strings.HasPrefix(token, prefix) {
diff --git a/backend/plugins/github/models/connection_test.go
b/backend/plugins/github/models/connection_test.go
index 7b39cf0bd..41323b9b5 100644
--- a/backend/plugins/github/models/connection_test.go
+++ b/backend/plugins/github/models/connection_test.go
@@ -227,3 +227,14 @@ func TestGithubConnection_Sanitize(t *testing.T) {
})
}
}
+
// TestTokenTypeClassification checks typeIs against every recognised token
// prefix, including the ghu_ prefix that was added to the classical list.
func TestTokenTypeClassification(t *testing.T) {
	conn := &GithubConn{}
	expectations := []struct {
		token string
		want  string
	}{
		{"ghp_123", GithubTokenTypeClassical},
		{"gho_123", GithubTokenTypeClassical},
		{"ghu_123", GithubTokenTypeClassical},
		{"ghs_123", GithubTokenTypeClassical},
		{"ghr_123", GithubTokenTypeClassical},
		{"github_pat_123", GithubTokenTypeFineGrained},
		{"some_other_token", GithubTokenTypeUnknown},
	}
	for _, e := range expectations {
		assert.Equal(t, e.want, conn.typeIs(e.token), "token %q", e.token)
	}
}
diff --git
a/backend/plugins/github/models/migrationscripts/20241120_add_refresh_token_fields.go
b/backend/plugins/github/models/migrationscripts/20241120_add_refresh_token_fields.go
new file mode 100644
index 000000000..b2f826da3
--- /dev/null
+++
b/backend/plugins/github/models/migrationscripts/20241120_add_refresh_token_fields.go
@@ -0,0 +1,53 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package migrationscripts
+
+import (
+ "time"
+
+ "github.com/apache/incubator-devlake/core/context"
+ "github.com/apache/incubator-devlake/core/errors"
+ "github.com/apache/incubator-devlake/helpers/migrationhelper"
+)
+
+type githubConnection20241120 struct {
+ RefreshToken string `gorm:"type:text;serializer:encdec"`
+ TokenExpiresAt time.Time
+ RefreshTokenExpiresAt time.Time
+}
+
+func (githubConnection20241120) TableName() string {
+ return "_tool_github_connections"
+}
+
+type addRefreshTokenFields struct{}
+
+func (*addRefreshTokenFields) Up(basicRes context.BasicRes) errors.Error {
+ return migrationhelper.AutoMigrateTables(
+ basicRes,
+ &githubConnection20241120{},
+ )
+}
+
+func (*addRefreshTokenFields) Version() uint64 {
+ return 20241120000001
+}
+
+func (*addRefreshTokenFields) Name() string {
+ return "add refresh token fields to github_connections"
+}
diff --git a/backend/plugins/github/models/migrationscripts/register.go
b/backend/plugins/github/models/migrationscripts/register.go
index b8a0722eb..74f9d712b 100644
--- a/backend/plugins/github/models/migrationscripts/register.go
+++ b/backend/plugins/github/models/migrationscripts/register.go
@@ -55,5 +55,6 @@ func All() []plugin.MigrationScript {
new(addIsDraftToPr),
new(changeIssueComponentType),
new(addIndexToGithubJobs),
+ new(addRefreshTokenFields),
}
}
diff --git a/backend/plugins/github/tasks/api_client.go
b/backend/plugins/github/tasks/api_client.go
index c9bfa852c..268af8ece 100644
--- a/backend/plugins/github/tasks/api_client.go
+++ b/backend/plugins/github/tasks/api_client.go
@@ -26,6 +26,7 @@ import (
"github.com/apache/incubator-devlake/core/plugin"
"github.com/apache/incubator-devlake/helpers/pluginhelper/api"
"github.com/apache/incubator-devlake/plugins/github/models"
+ "github.com/apache/incubator-devlake/plugins/github/token"
)
func CreateApiClient(taskCtx plugin.TaskContext, connection
*models.GithubConnection) (*api.ApiAsyncClient, errors.Error) {
@@ -34,6 +35,24 @@ func CreateApiClient(taskCtx plugin.TaskContext, connection
*models.GithubConnec
return nil, err
}
+ // Inject TokenProvider if refresh token is present
+ if connection.RefreshToken != "" {
+ logger := taskCtx.GetLogger()
+ db := taskCtx.GetDal()
+
+ // Create TokenProvider
+ tp := token.NewTokenProvider(connection, db,
apiClient.GetClient(), logger)
+
+ // Wrap the transport
+ baseTransport := apiClient.GetClient().Transport
+ if baseTransport == nil {
+ baseTransport = http.DefaultTransport
+ }
+
+ rt := token.NewRefreshRoundTripper(baseTransport, tp)
+ apiClient.GetClient().Transport = rt
+ }
+
// create rate limit calculator
rateLimiter := &api.ApiRateLimitCalculator{
UserRateLimitPerHour: connection.RateLimitPerHour,
diff --git a/backend/plugins/github/token/round_tripper.go
b/backend/plugins/github/token/round_tripper.go
new file mode 100644
index 000000000..45ba3e9a7
--- /dev/null
+++ b/backend/plugins/github/token/round_tripper.go
@@ -0,0 +1,90 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package token
+
+import (
+ "net/http"
+)
+
+// RefreshRoundTripper is an HTTP transport middleware that automatically
manages OAuth token refreshes.
+// It wraps an underlying http.RoundTripper and provides token refresh on auth
failures.
+// On 401's the round tripper will:
+// - Force a refresh of the OAuth token via the TokenProvider
+// - Retry the original request with the new token
+type RefreshRoundTripper struct {
+ base http.RoundTripper
+ tokenProvider *TokenProvider
+}
+
+func NewRefreshRoundTripper(base http.RoundTripper, tp *TokenProvider)
*RefreshRoundTripper {
+ return &RefreshRoundTripper{
+ base: base,
+ tokenProvider: tp,
+ }
+}
+
+// RoundTrip implements the http.RoundTripper interface and handles automatic
token refresh on 401 responses.
+// It clones the request, adds the Authorization header, and retries once with
a refreshed token if needed.
+func (rt *RefreshRoundTripper) RoundTrip(req *http.Request) (*http.Response,
error) {
+ return rt.roundTripWithRetry(req, false)
+}
+
// roundTripWithRetry sends req with the provider's current access token in the
// Authorization header. If the server answers 401 Unauthorized and no refresh
// has been attempted yet for this request, it forces a token refresh and
// replays the request exactly once. refreshAttempted prevents an endless
// refresh/retry loop when the new token is rejected as well.
//
// The first attempt consumes the request body, so a request with a body is
// only replayed when the body can be recreated through req.GetBody.
// http.NewRequest sets GetBody automatically for *bytes.Buffer, *bytes.Reader
// and *strings.Reader bodies. Otherwise the 401 response is returned to the
// caller unchanged instead of resending an empty or truncated payload.
//
// As the http.RoundTripper contract requires, the request body is closed on
// error paths handled here.
func (rt *RefreshRoundTripper) roundTripWithRetry(req *http.Request, refreshAttempted bool) (*http.Response, error) {
	token, err := rt.tokenProvider.GetToken()
	if err != nil {
		closeRequestBody(req)
		return nil, err
	}

	// Never mutate the caller's request; attach the header to a clone.
	reqClone := req.Clone(req.Context())
	reqClone.Header.Set("Authorization", "Bearer "+token)

	resp, reqErr := rt.base.RoundTrip(reqClone)
	if reqErr != nil {
		return nil, reqErr
	}
	if resp.StatusCode != http.StatusUnauthorized || refreshAttempted {
		return resp, nil
	}

	// Prepare the replay before discarding the 401, so the caller still gets
	// the original response when the request cannot be replayed.
	retryReq, ok := replayableRequest(req)
	if !ok {
		return resp, nil
	}
	resp.Body.Close()

	// ForceRefresh skips the refresh when another goroutine already replaced
	// the token we just used.
	if err := rt.tokenProvider.ForceRefresh(token); err != nil {
		closeRequestBody(retryReq)
		return nil, err
	}

	// The recursive call fetches the refreshed token and sets the header.
	return rt.roundTripWithRetry(retryReq, true)
}

// replayableRequest returns a request that can be sent again after req's body
// was consumed by a previous attempt. Bodiless requests are returned as-is.
// Requests with a body get a clone whose Body is produced by req.GetBody.
// ok is false when the body cannot be recreated.
func replayableRequest(req *http.Request) (*http.Request, bool) {
	if req.Body == nil || req.Body == http.NoBody {
		return req, true
	}
	if req.GetBody == nil {
		return nil, false
	}
	body, err := req.GetBody()
	if err != nil {
		return nil, false
	}
	replay := req.Clone(req.Context())
	replay.Body = body
	return replay, true
}

// closeRequestBody closes req's body if it has one, ignoring the close error.
func closeRequestBody(req *http.Request) {
	if req != nil && req.Body != nil {
		_ = req.Body.Close()
	}
}
diff --git a/backend/plugins/github/token/round_tripper_test.go
b/backend/plugins/github/token/round_tripper_test.go
new file mode 100644
index 000000000..6767d8ccf
--- /dev/null
+++ b/backend/plugins/github/token/round_tripper_test.go
@@ -0,0 +1,101 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package token
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/apache/incubator-devlake/helpers/pluginhelper/api"
+ "github.com/apache/incubator-devlake/impls/logruslog"
+ "github.com/apache/incubator-devlake/plugins/github/models"
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+func TestRoundTripper401Refresh(t *testing.T) {
+ mockRT := new(MockRoundTripper)
+ client := &http.Client{Transport: mockRT}
+
+ conn := &models.GithubConnection{
+ GithubConn: models.GithubConn{
+ RefreshToken: "refresh_token",
+ GithubAccessToken: models.GithubAccessToken{
+ AccessToken: api.AccessToken{
+ Token: "old_token",
+ },
+ },
+ TokenExpiresAt: time.Now().Add(10 * time.Minute), //
Not expired
+ GithubAppKey: models.GithubAppKey{
+ AppKey: api.AppKey{
+ AppId: "123",
+ SecretKey: "secret",
+ },
+ },
+ },
+ }
+
+ logger, _ := logruslog.NewDefaultLogger(logrus.New())
+ tp := NewTokenProvider(conn, nil, client, logger)
+ rt := NewRefreshRoundTripper(mockRT, tp)
+
+ // Request
+ req, _ := http.NewRequest("GET", "https://api.github.com/user", nil)
+
+ // 1. First call returns 401
+ resp401 := &http.Response{
+ StatusCode: 401,
+ Body: io.NopCloser(bytes.NewBufferString("Unauthorized")),
+ }
+ mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool {
+ return r.Header.Get("Authorization") == "Bearer old_token" &&
r.URL.String() == "https://api.github.com/user"
+ })).Return(resp401, nil).Once()
+
+ // 2. Refresh call (triggered by 401)
+ respRefresh := &http.Response{
+ StatusCode: 200,
+ Body:
io.NopCloser(bytes.NewBufferString(`{"access_token":"new_token","refresh_token":"new_refresh_token","expires_in":3600,"refresh_token_expires_in":3600}`)),
+ }
+ // The refresh call uses the same client, so it goes through mockRT too!
+ mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool {
+ return r.URL.String() ==
"https://github.com/login/oauth/access_token"
+ })).Return(respRefresh, nil).Once()
+
+ // 3. Retry call with new token
+ resp200 := &http.Response{
+ StatusCode: 200,
+ Body: io.NopCloser(bytes.NewBufferString("Success")),
+ }
+ mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool {
+ return r.Header.Get("Authorization") == "Bearer new_token" &&
r.URL.String() == "https://api.github.com/user"
+ })).Return(resp200, nil).Once()
+
+ // Execute
+ resp, err := rt.RoundTrip(req)
+ assert.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+
+ body, _ := io.ReadAll(resp.Body)
+ assert.Equal(t, "Success", string(body))
+
+ mockRT.AssertExpectations(t)
+}
diff --git a/backend/plugins/github/token/token_provider.go
b/backend/plugins/github/token/token_provider.go
new file mode 100644
index 000000000..ba9941cd4
--- /dev/null
+++ b/backend/plugins/github/token/token_provider.go
@@ -0,0 +1,185 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package token
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/apache/incubator-devlake/core/dal"
+ "github.com/apache/incubator-devlake/core/errors"
+ "github.com/apache/incubator-devlake/core/log"
+ "github.com/apache/incubator-devlake/plugins/github/models"
+)
+
+const (
+ DefaultRefreshBuffer = 5 * time.Minute
+)
+
+type TokenProvider struct {
+ conn *models.GithubConnection
+ dal dal.Dal
+ httpClient *http.Client
+ logger log.Logger
+ mu sync.Mutex
+ refreshURL string
+}
+
+// NewTokenProvider creates a TokenProvider for the given GitHub connection
using
+// the provided DAL, HTTP client, and logger, and returns a pointer to it.
+func NewTokenProvider(conn *models.GithubConnection, d dal.Dal, client
*http.Client, logger log.Logger) *TokenProvider {
+ return &TokenProvider{
+ conn: conn,
+ dal: d,
+ httpClient: client,
+ logger: logger,
+ refreshURL: "https://github.com/login/oauth/access_token",
+ }
+}
+
+func (tp *TokenProvider) GetToken() (string, errors.Error) {
+ tp.mu.Lock()
+ defer tp.mu.Unlock()
+
+ if tp.needsRefresh() {
+ if err := tp.refreshToken(); err != nil {
+ return "", err
+ }
+ }
+ return tp.conn.Token, nil
+}
+
// needsRefresh reports whether the access token should be refreshed now. That
// is the case when the connection has a refresh token and the access token
// expires within the refresh buffer.
//
// The buffer defaults to DefaultRefreshBuffer. It can be overridden in whole
// minutes via GITHUB_TOKEN_REFRESH_BUFFER_MINUTES. Unparsable or negative
// values are ignored: a negative buffer would keep handing out tokens after
// they had already expired. GetToken calls this with tp.mu held.
func (tp *TokenProvider) needsRefresh() bool {
	if tp.conn.RefreshToken == "" {
		return false
	}

	buffer := DefaultRefreshBuffer
	if envBuffer := os.Getenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES"); envBuffer != "" {
		if val, err := strconv.Atoi(envBuffer); err == nil && val >= 0 {
			buffer = time.Duration(val) * time.Minute
		}
	}

	return time.Now().Add(buffer).After(tp.conn.TokenExpiresAt)
}
+
// refreshToken exchanges the connection's refresh token for a new access token
// at tp.refreshURL and stores the result on tp.conn via UpdateToken. When a DAL
// is configured, the new credentials are also written back. A persistence
// failure is only logged, because the in-memory token is already usable.
// Both callers (GetToken and ForceRefresh) hold tp.mu while calling this.
func (tp *TokenProvider) refreshToken() errors.Error {
	// Response bodies are echoed into logs and error messages; cap their size.
	const maxBodySnippet = 512
	snippet := func(b []byte) string {
		s := string(b)
		if len(s) > maxBodySnippet {
			s = s[:maxBodySnippet] + "…"
		}
		return s
	}

	if tp.logger != nil {
		tp.logger.Info("Refreshing GitHub token for connection %d", tp.conn.ID)
	}

	jsonData, err := json.Marshal(map[string]string{
		"refresh_token": tp.conn.RefreshToken,
		"grant_type":    "refresh_token",
		"client_id":     tp.conn.AppId,
		"client_secret": tp.conn.SecretKey,
	})
	if err != nil {
		return errors.Convert(err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "POST", tp.refreshURL, bytes.NewBuffer(jsonData))
	if err != nil {
		return errors.Convert(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	resp, err := tp.httpClient.Do(req)
	if err != nil {
		return errors.Convert(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return errors.Convert(err)
	}

	if resp.StatusCode != http.StatusOK {
		// Log the (truncated) response body to aid in debugging refresh failures.
		if tp.logger != nil {
			tp.logger.Error(nil, "failed to refresh token from GitHub, status=%d, body=%s", resp.StatusCode, snippet(body))
		}
		return errors.Default.New(fmt.Sprintf("failed to refresh token: %d, body: %s", resp.StatusCode, snippet(body)))
	}

	var result struct {
		AccessToken           string `json:"access_token"`
		RefreshToken          string `json:"refresh_token"`
		ExpiresIn             int    `json:"expires_in"`
		RefreshTokenExpiresIn int    `json:"refresh_token_expires_in"`
		// GitHub reports refresh failures (e.g. "bad_refresh_token") with
		// HTTP 200 and these fields in place of a token.
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return errors.Convert(err)
	}

	if result.AccessToken == "" {
		if result.Error != "" {
			return errors.Default.New(fmt.Sprintf("failed to refresh token: %s: %s", result.Error, result.ErrorDescription))
		}
		return errors.Default.New(fmt.Sprintf("empty access token returned; response body: %s", snippet(body)))
	}

	// Use one reference instant so both expiries are computed consistently.
	now := time.Now()
	tp.conn.UpdateToken(
		result.AccessToken,
		result.RefreshToken,
		now.Add(time.Duration(result.ExpiresIn)*time.Second),
		now.Add(time.Duration(result.RefreshTokenExpiresIn)*time.Second),
	)

	if tp.dal != nil {
		// NOTE(review): refresh_token is declared with the encdec serializer
		// (and token presumably is too). Confirm that UpdateColumns applies
		// field serializers; otherwise these secrets would be written
		// unencrypted.
		if perr := tp.dal.UpdateColumns(tp.conn, []dal.DalSet{
			{ColumnName: "token", Value: tp.conn.Token},
			{ColumnName: "refresh_token", Value: tp.conn.RefreshToken},
			{ColumnName: "token_expires_at", Value: tp.conn.TokenExpiresAt},
			{ColumnName: "refresh_token_expires_at", Value: tp.conn.RefreshTokenExpiresAt},
		}); perr != nil && tp.logger != nil {
			tp.logger.Warn(perr, "failed to persist refreshed token")
		}
	}

	return nil
}
+
+// ForceRefresh refreshes the access token if the current token is still equal
to oldToken.
+// The oldToken parameter should be the token value observed by the caller
when it determined
+// that a refresh might be needed; if the token has changed since then,
another goroutine has
+// already refreshed it and this method returns without performing a redundant
refresh.
+func (tp *TokenProvider) ForceRefresh(oldToken string) errors.Error {
+ tp.mu.Lock()
+ defer tp.mu.Unlock()
+
+ // If the token has changed since the request was made, it means
another thread
+ // has already refreshed it.
+ if tp.conn.Token != oldToken {
+ return nil
+ }
+
+ return tp.refreshToken()
+}
diff --git a/backend/plugins/github/token/token_provider_test.go
b/backend/plugins/github/token/token_provider_test.go
new file mode 100644
index 000000000..1c296376a
--- /dev/null
+++ b/backend/plugins/github/token/token_provider_test.go
@@ -0,0 +1,180 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package token
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "os"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/apache/incubator-devlake/core/errors"
+ "github.com/apache/incubator-devlake/helpers/pluginhelper/api"
+ "github.com/apache/incubator-devlake/impls/logruslog"
+ mockdal "github.com/apache/incubator-devlake/mocks/core/dal"
+ "github.com/apache/incubator-devlake/plugins/github/models"
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+type MockRoundTripper struct {
+ mock.Mock
+}
+
+func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response,
error) {
+ args := m.Called(req)
+ return args.Get(0).(*http.Response), args.Error(1)
+}
+
// TestNeedsRefresh checks the default-buffer expiry logic and the rule that
// connections without a refresh token never need a refresh.
func TestNeedsRefresh(t *testing.T) {
	cases := []struct {
		name         string
		refreshToken string
		expiresIn    time.Duration
		want         bool
	}{
		{"not expired, outside buffer", "refresh_token", 10 * time.Minute, false},
		{"inside buffer", "refresh_token", 1 * time.Minute, true},
		{"expired", "refresh_token", -1 * time.Minute, true},
		{"no refresh token", "", -1 * time.Minute, false},
	}
	for _, tc := range cases {
		tp := &TokenProvider{
			conn: &models.GithubConnection{
				GithubConn: models.GithubConn{
					RefreshToken:   tc.refreshToken,
					TokenExpiresAt: time.Now().Add(tc.expiresIn),
				},
			},
		}
		assert.Equal(t, tc.want, tp.needsRefresh(), tc.name)
	}
}
+
// TestTokenProviderConcurrency starts several goroutines on an expired token
// and checks that exactly one refresh request is made and that every caller
// receives the new token.
func TestTokenProviderConcurrency(t *testing.T) {
	const workers = 10
	transport := new(MockRoundTripper)

	conn := &models.GithubConnection{
		GithubConn: models.GithubConn{
			RefreshToken:   "refresh_token",
			TokenExpiresAt: time.Now().Add(-1 * time.Minute), // already expired
			GithubAppKey: models.GithubAppKey{
				AppKey: api.AppKey{AppId: "123", SecretKey: "secret"},
			},
		},
	}
	logger, _ := logruslog.NewDefaultLogger(logrus.New())
	tp := NewTokenProvider(conn, nil, &http.Client{Transport: transport}, logger)

	// A single stubbed refresh: a second refresh request would be reported by
	// the mock as unexpected.
	refreshJSON := `{"access_token":"new_token","refresh_token":"new_refresh_token","expires_in":3600,"refresh_token_expires_in":3600}`
	transport.On("RoundTrip", mock.Anything).Return(&http.Response{
		StatusCode: 200,
		Body:       io.NopCloser(bytes.NewBufferString(refreshJSON)),
	}, nil).Once()

	var wg sync.WaitGroup
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			got, err := tp.GetToken()
			assert.NoError(t, err)
			assert.Equal(t, "new_token", got)
		}()
	}
	wg.Wait()

	transport.AssertExpectations(t)
}
+
+func TestConfigurableBuffer(t *testing.T) {
+ os.Setenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES", "10")
+ defer os.Unsetenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES")
+
+ tp := &TokenProvider{
+ conn: &models.GithubConnection{
+ GithubConn: models.GithubConn{
+ RefreshToken: "refresh_token",
+ },
+ },
+ }
+
+ // 9 minutes remaining (inside 10m buffer)
+ tp.conn.TokenExpiresAt = time.Now().Add(9 * time.Minute)
+ assert.True(t, tp.needsRefresh())
+
+ // 11 minutes remaining (outside 10m buffer)
+ tp.conn.TokenExpiresAt = time.Now().Add(11 * time.Minute)
+ assert.False(t, tp.needsRefresh())
+}
+
// TestPersistenceFailure checks that a failed database write after a
// successful refresh is not reported as an error. It also checks that the
// refreshed credentials are still applied to the in-memory connection.
func TestPersistenceFailure(t *testing.T) {
	mockRT := new(MockRoundTripper)
	client := &http.Client{Transport: mockRT}
	mockDal := new(mockdal.Dal)

	conn := &models.GithubConnection{
		GithubConn: models.GithubConn{
			RefreshToken: "refresh_token",
			GithubAccessToken: models.GithubAccessToken{
				AccessToken: api.AccessToken{
					Token: "old_token",
				},
			},
			GithubAppKey: models.GithubAppKey{
				AppKey: api.AppKey{
					AppId:     "123",
					SecretKey: "secret",
				},
			},
		},
	}

	logger, _ := logruslog.NewDefaultLogger(logrus.New())
	tp := NewTokenProvider(conn, mockDal, client, logger)

	// GitHub accepts the refresh request...
	respBody := `{"access_token":"new_token","refresh_token":"new_refresh_token","expires_in":3600,"refresh_token_expires_in":3600}`
	resp := &http.Response{
		StatusCode: 200,
		Body:       io.NopCloser(bytes.NewBufferString(respBody)),
	}
	mockRT.On("RoundTrip", mock.Anything).Return(resp, nil).Once()

	// ...but writing the new credentials to the database fails.
	mockDal.On("UpdateColumns", mock.Anything, mock.Anything, mock.AnythingOfType("[]dal.Clause")).Return(errors.Default.New("db error"))

	err := tp.ForceRefresh("old_token")
	// Persistence failures are logged, not returned.
	assert.NoError(t, err)

	// The in-memory connection must carry the new credentials regardless, so
	// subsequent requests use the refreshed token.
	assert.Equal(t, "new_token", conn.Token)
	assert.Equal(t, "new_refresh_token", conn.RefreshToken)
	assert.True(t, conn.TokenExpiresAt.After(time.Now()), "access token expiry should be in the future")

	mockRT.AssertExpectations(t)
	mockDal.AssertExpectations(t)
}