This is an automated email from the ASF dual-hosted git repository.

vavila pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new cc972cad5ac fix: DB OAuth2 fixes (#37350)
cc972cad5ac is described below

commit cc972cad5ac19328d77662a4e377a10d5a8b2eac
Author: Vitor Avila <[email protected]>
AuthorDate: Thu Jan 22 01:51:48 2026 -0300

    fix: DB OAuth2 fixes (#37350)
---
 superset/db_engine_specs/base.py     | 20 ++++++---
 superset/models/core.py              | 34 ++++++++-------
 superset/utils/oauth2.py             |  5 ++-
 tests/unit_tests/models/core_test.py | 81 ++++++++++++++++++++++++++++++++++++
 4 files changed, 118 insertions(+), 22 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 4d72be41d8a..4113a3e8fe5 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -717,9 +717,13 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             "redirect_uri": config["redirect_uri"],
             "grant_type": "authorization_code",
         }
-        if config["request_content_type"] == "data":
-            return requests.post(uri, data=req_body, timeout=timeout).json()
-        return requests.post(uri, json=req_body, timeout=timeout).json()
+        response = (
+            requests.post(uri, data=req_body, timeout=timeout)
+            if config["request_content_type"] == "data"
+            else requests.post(uri, json=req_body, timeout=timeout)
+        )
+        response.raise_for_status()
+        return response.json()
 
     @classmethod
     def get_oauth2_fresh_token(
@@ -738,9 +742,13 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             "refresh_token": refresh_token,
             "grant_type": "refresh_token",
         }
-        if config["request_content_type"] == "data":
-            return requests.post(uri, data=req_body, timeout=timeout).json()
-        return requests.post(uri, json=req_body, timeout=timeout).json()
+        response = (
+            requests.post(uri, data=req_body, timeout=timeout)
+            if config["request_content_type"] == "data"
+            else requests.post(uri, json=req_body, timeout=timeout)
+        )
+        response.raise_for_status()
+        return response.json()
 
     @classmethod
     def get_allows_alias_in_select(
diff --git a/superset/models/core.py b/superset/models/core.py
index cb7bdf2d352..d13c14b65ab 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -896,9 +896,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # pylint:
                     )
                 }
         except Exception as ex:
-            if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
-                self.start_oauth2_dance()
-
+            self._handle_oauth2_error(ex)
             raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
 
     @cache_util.memoized_func(
@@ -933,9 +931,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # pylint:
                     )
                 }
         except Exception as ex:
-            if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
-                self.start_oauth2_dance()
-
+            self._handle_oauth2_error(ex)
             raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
 
     @cache_util.memoized_func(
@@ -972,9 +968,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # pylint:
                     )
                 }
         except Exception as ex:
-            if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
-                self.start_oauth2_dance()
-
+            self._handle_oauth2_error(ex)
             raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
 
         return set()
@@ -1003,9 +997,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # pylint:
             with self.get_inspector(catalog=catalog) as inspector:
                 return self.db_engine_spec.get_schema_names(inspector)
         except Exception as ex:
-            if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
-                self.start_oauth2_dance()
-
+            self._handle_oauth2_error(ex)
             raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
 
     @cache_util.memoized_func(
@@ -1022,9 +1014,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # pylint:
             with self.get_inspector() as inspector:
                 return self.db_engine_spec.get_catalog_names(self, inspector)
         except Exception as ex:
-            if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
-                self.start_oauth2_dance()
-
+            self._handle_oauth2_error(ex)
             raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
 
     @property
@@ -1261,6 +1251,10 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # pylint:
         if oauth2_client_info := encrypted_extra.get("oauth2_client_info"):
             schema = OAuth2ClientConfigSchema()
             client_config = schema.load(oauth2_client_info)
+            if "request_content_type" not in oauth2_client_info:
+                client_config["request_content_type"] = (
+                    self.db_engine_spec.oauth2_token_request_type
+                )
             return cast(OAuth2ClientConfig, client_config)
 
         return self.db_engine_spec.get_oauth2_config()
@@ -1275,6 +1269,16 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # pylint:
         """
         return self.db_engine_spec.start_oauth2_dance(self)
 
+    def _handle_oauth2_error(self, ex: Exception) -> None:
+        """
+        Handle exceptions that may require OAuth2 authentication.
+
+        If OAuth2 is enabled and the exception indicates that OAuth2 is needed,
+        starts the OAuth2 dance.
+        """
+        if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
+            self.start_oauth2_dance()
+
     def purge_oauth2_tokens(self) -> None:
         """
         Delete all OAuth2 tokens associated with this database.
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index ebe1f4012eb..cd1a2a14d9e 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -189,7 +189,10 @@ class OAuth2ClientConfigSchema(Schema):
     scope = fields.String(required=True)
     redirect_uri = fields.String(
         required=False,
-        load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True),
+        load_default=lambda: app.config.get(
+            "DATABASE_OAUTH2_REDIRECT_URI",
+            url_for("DatabaseRestApi.oauth2", _external=True),
+        ),
     )
     authorization_request_uri = fields.String(required=True)
     token_request_uri = fields.String(required=True)
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 998a1033bb0..7d7aa96ea19 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -660,6 +660,34 @@ def test_get_oauth2_config(app_context: None) -> None:
 
     assert database.get_oauth2_config() is None
 
+    database.encrypted_extra = json.dumps(oauth2_client_info)
+    assert database.get_oauth2_config() == {
+        "id": "my_client_id",
+        "secret": "my_client_secret",
+        "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
+        "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
+        "scope": "refresh_token session:role:USERADMIN",
+        "redirect_uri": "http://example.com/api/v1/database/oauth2/",
+        "request_content_type": "data",  # Default value from BaseEngineSpec
+    }
+
+
+def test_get_oauth2_config_token_request_type_from_db_engine_specs(
+    mocker: MockerFixture, app_context: None
+) -> None:
+    """
+    Test that DB Engine Spec overrides for ``oauth2_token_request_type`` are respected.
+    """
+    database = Database(
+        database_name="db",
+        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+    )
+    mocker.patch.object(
+        database.db_engine_spec,
+        "oauth2_token_request_type",
+        "json",
+    )
+
     database.encrypted_extra = json.dumps(oauth2_client_info)
     assert database.get_oauth2_config() == {
         "id": "my_client_id",
@@ -672,6 +700,59 @@ def test_get_oauth2_config(app_context: None) -> None:
     }
 
 
+def test_get_oauth2_config_custom_token_request_type_extra(app_context: None) -> None:
+    """
+    Test passing a custom ``token_request_type`` via ``encrypted_extra``
+    takes precedence.
+    """
+    database = Database(
+        database_name="db",
+        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+    )
+    custom_oauth2_client_info = {
+        "oauth2_client_info": {
+            **oauth2_client_info["oauth2_client_info"],
+            "request_content_type": "json",
+        }
+    }
+
+    database.encrypted_extra = json.dumps(custom_oauth2_client_info)
+    assert database.get_oauth2_config() == {
+        "id": "my_client_id",
+        "secret": "my_client_secret",
+        "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
+        "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
+        "scope": "refresh_token session:role:USERADMIN",
+        "redirect_uri": "http://example.com/api/v1/database/oauth2/",
+        "request_content_type": "json",
+    }
+
+
+def test_get_oauth2_config_redirect_uri_from_config(
+    mocker: MockerFixture,
+    app_context: None,
+) -> None:
+    """
+    Test that ``DATABASE_OAUTH2_REDIRECT_URI`` config takes precedence over
+    url_for default.
+    """
+    custom_redirect_uri = "https://custom.example.com/oauth/callback"
+    mocker.patch.dict(
+        "superset.utils.oauth2.app.config",
+        {"DATABASE_OAUTH2_REDIRECT_URI": custom_redirect_uri},
+    )
+    database = Database(
+        database_name="db",
+        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+    )
+    database.encrypted_extra = json.dumps(oauth2_client_info)
+
+    config = database.get_oauth2_config()
+
+    assert config is not None
+    assert config["redirect_uri"] == custom_redirect_uri
+
+
 def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
     """
     Test that we can start OAuth2 from `raw_connection()` errors.

Reply via email to