This is an automated email from the ASF dual-hosted git repository.

Vitor-Avila pushed a commit to branch fix/oauth2-race-condition-bug
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 47676a35eb4cddd1b58565a4fad8165916732364
Author: Vitor Avila <[email protected]>
AuthorDate: Tue May 12 12:29:55 2026 -0300

    fix(OAuth2): Re-query the OAuth2 token to avoid stale reference
---
 superset/utils/oauth2.py               | 17 ++++++++++++---
 tests/unit_tests/utils/oauth2_tests.py | 38 ++++++++++++++++++++++++++++------
 2 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index a2a2666e7c0..48068f0053c 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -38,7 +38,7 @@ from superset.superset_typing import OAuth2ClientConfig, OAuth2State
 
 if TYPE_CHECKING:
     from superset.db_engine_specs.base import BaseEngineSpec
-    from superset.models.core import Database, DatabaseUserOAuth2Tokens
+    from superset.models.core import Database
 
 JWT_EXPIRATION = timedelta(minutes=5)
 
@@ -116,7 +116,7 @@ def get_oauth2_access_token(
         return token.access_token
 
     if token.refresh_token:
-        return refresh_oauth2_token(config, database_id, user_id, db_engine_spec, token)
+        return refresh_oauth2_token(config, database_id, user_id, db_engine_spec)
 
     # since the access token is expired and there's no refresh token, delete the entry
     db.session.delete(token)
@@ -129,8 +129,10 @@ def refresh_oauth2_token(
     database_id: int,
     user_id: int,
     db_engine_spec: type[BaseEngineSpec],
-    token: DatabaseUserOAuth2Tokens,
 ) -> str | None:
+    # pylint: disable=import-outside-toplevel
+    from superset.models.core import DatabaseUserOAuth2Tokens
+
     # Use longer TTL for OAuth2 token refresh (may involve network calls)
     with DistributedLock(
         namespace="refresh_oauth2_token",
@@ -138,6 +140,15 @@ def refresh_oauth2_token(
         user_id=user_id,
         database_id=database_id,
     ):
+        # Short circuit in case another request already deleted the token
+        token = (
+            db.session.query(DatabaseUserOAuth2Tokens)
+            .filter_by(user_id=user_id, database_id=database_id)
+            .one_or_none()
+        )
+        if token is None:
+            return None
+
         try:
             token_response = db_engine_spec.get_oauth2_fresh_token(
                 config,
diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py
index c74b2f9570b..57ae95804a7 100644
--- a/tests/unit_tests/utils/oauth2_tests.py
+++ b/tests/unit_tests/utils/oauth2_tests.py
@@ -132,9 +132,10 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
     )
     token = mocker.MagicMock()
     token.refresh_token = "refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
 
     with pytest.raises(OAuth2ExceptionError):
-        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
 
     db.session.delete.assert_called_with(token)
     db.session.flush.assert_called_once()
@@ -161,9 +162,10 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
     db_engine_spec.get_oauth2_fresh_token.side_effect = Exception("Network error")
     token = mocker.MagicMock()
     token.refresh_token = "refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
 
     with pytest.raises(Exception, match="Network error"):
-        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
 
     db.session.delete.assert_not_called()
 
@@ -176,7 +178,7 @@ def test_refresh_oauth2_token_no_access_token_in_response(
 
     This can happen when the refresh token was revoked.
     """
-    mocker.patch("superset.utils.oauth2.db")
+    db = mocker.patch("superset.utils.oauth2.db")
     mocker.patch("superset.utils.oauth2.DistributedLock")
     db_engine_spec = mocker.MagicMock()
     db_engine_spec.get_oauth2_fresh_token.return_value = {
@@ -184,8 +186,9 @@ def test_refresh_oauth2_token_no_access_token_in_response(
     }
     token = mocker.MagicMock()
     token.refresh_token = "refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
 
-    result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+    result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
 
     assert result is None
 
@@ -209,9 +212,10 @@ def test_refresh_oauth2_token_updates_refresh_token(
     }
     token = mocker.MagicMock()
     token.refresh_token = "old-refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
 
     with freeze_time("2024-01-01"):
-        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
 
     assert token.access_token == "new-access-token"  # noqa: S105
     assert token.access_token_expiration == datetime(2024, 1, 1, 1)
@@ -237,15 +241,37 @@ def test_refresh_oauth2_token_keeps_refresh_token(
     }
     token = mocker.MagicMock()
     token.refresh_token = "original-refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
 
     with freeze_time("2024-01-01"):
-        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
 
     assert token.access_token == "new-access-token"  # noqa: S105
     assert token.refresh_token == "original-refresh-token"  # noqa: S105
     db.session.add.assert_called_with(token)
 
 
+def test_refresh_oauth2_token_returns_none_when_row_deleted_under_lock(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that refresh_oauth2_token returns None when the row is gone under the lock.
+
+    When concurrent requests are triggered and the first one deletes the token row and
+    releases the lock before the second one gets to `refresh_oauth2_token`, the token
+    is queried again to avoid a stale reference.
+    """
+    db = mocker.patch("superset.utils.oauth2.db")
+    mocker.patch("superset.utils.oauth2.DistributedLock")
+    db_engine_spec = mocker.MagicMock()
+    db.session.query().filter_by().one_or_none.return_value = None
+
+    result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
+
+    assert result is None
+    db_engine_spec.get_oauth2_fresh_token.assert_not_called()
+
+
 def test_generate_code_verifier_length() -> None:
     """
     Test that generate_code_verifier produces a string of valid length (RFC 7636).

Reply via email to