This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 5.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 92bc43e38b135bbd4bc1b365dc2f3b5935bb2fab
Author: Beto Dealmeida <[email protected]>
AuthorDate: Tue Feb 4 18:24:05 2025 -0500

    fix: move oauth2 capture to `get_sqla_engine` (#32137)
    
    (cherry picked from commit c7c3b1b0e99228f261415742469cd4a7f929da7b)
---
 superset/models/core.py              | 31 ++++++------
 superset/utils/oauth2.py             | 18 ++++++-
 tests/unit_tests/models/core_test.py | 98 +++++++++++++++++++++++++++++++++---
 3 files changed, 123 insertions(+), 24 deletions(-)

diff --git a/superset/models/core.py b/superset/models/core.py
index 6f32383ab8..96a1953fae 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -84,7 +84,11 @@ from superset.superset_typing import (
 from superset.utils import cache as cache_util, core as utils, json
 from superset.utils.backports import StrEnum
 from superset.utils.core import get_username
-from superset.utils.oauth2 import get_oauth2_access_token, OAuth2ClientConfigSchema
+from superset.utils.oauth2 import (
+    check_for_oauth2,
+    get_oauth2_access_token,
+    OAuth2ClientConfigSchema,
+)
 
 config = app.config
 custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
@@ -451,13 +455,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
 
             engine_context_manager = config["ENGINE_CONTEXT_MANAGER"]
             with engine_context_manager(self, catalog, schema):
-                yield self._get_sqla_engine(
-                    catalog=catalog,
-                    schema=schema,
-                    nullpool=nullpool,
-                    source=source,
-                    sqlalchemy_uri=sqlalchemy_uri,
-                )
+                with check_for_oauth2(self):
+                    yield self._get_sqla_engine(
+                        catalog=catalog,
+                        schema=schema,
+                        nullpool=nullpool,
+                        source=source,
+                        sqlalchemy_uri=sqlalchemy_uri,
+                    )
 
     def _get_sqla_engine(  # pylint: disable=too-many-locals  # noqa: C901
         self,
@@ -583,10 +588,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
             nullpool=nullpool,
             source=source,
         ) as engine:
-            try:
+            with check_for_oauth2(self):
                 with closing(engine.raw_connection()) as conn:
-                    # pre-session queries are used to set the selected schema and, in the  # noqa: E501
-                    # future, the selected catalog
+                    # pre-session queries are used to set the selected catalog/schema
                     for prequery in self.db_engine_spec.get_prequeries(
                         database=self,
                         catalog=catalog,
@@ -597,11 +601,6 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
 
                     yield conn
 
-            except Exception as ex:
-                if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
-                    self.db_engine_spec.start_oauth2_dance(self)
-                raise
-
     def get_default_catalog(self) -> str | None:
         """
         Return the default configured catalog for the database.
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 0918f0792a..b93f89c870 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -17,8 +17,9 @@
 
 from __future__ import annotations
 
+from contextlib import contextmanager
 from datetime import datetime, timedelta, timezone
-from typing import Any, TYPE_CHECKING
+from typing import Any, Iterator, TYPE_CHECKING
 
 import backoff
 import jwt
@@ -32,7 +33,7 @@ from superset.superset_typing import OAuth2ClientConfig, OAuth2State
 
 if TYPE_CHECKING:
     from superset.db_engine_specs.base import BaseEngineSpec
-    from superset.models.core import DatabaseUserOAuth2Tokens
+    from superset.models.core import Database, DatabaseUserOAuth2Tokens
 
 JWT_EXPIRATION = timedelta(minutes=5)
 
@@ -197,3 +198,16 @@ class OAuth2ClientConfigSchema(Schema):
         load_default=lambda: "json",
         validate=validate.OneOf(["json", "data"]),
     )
+
+
+@contextmanager
+def check_for_oauth2(database: Database) -> Iterator[None]:
+    """
+    Run code and check if OAuth2 is needed.
+    """
+    try:
+        yield
+    except Exception as ex:
+        if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(ex):
+            database.db_engine_spec.start_oauth2_dance(database)
+        raise
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 5b269fc3ba..50ccde6605 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -558,16 +558,47 @@ def test_get_oauth2_config(app_context: None) -> None:
     }
 
 
-def test_raw_connection_oauth(mocker: MockerFixture) -> None:
+def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
     """
     Test that we can start OAuth2 from `raw_connection()` errors.
 
-    Some databases that use OAuth2 need to trigger the flow when the connection is
-    created, rather than when the query runs. This happens when the SQLAlchemy engine
-    URI cannot be built without the user personal token.
+    With OAuth2, some databases raise an exception when the engine is first created
+    (e.g., BigQuery), others (e.g., Snowflake) when the connection is created, and
+    others still (e.g., GSheets) when the query is executed.
 
-    This test verifies that the exception is captured and raised correctly so that the
-    frontend can trigger the OAuth2 dance.
+    This test verifies that when calling `raw_connection()` the OAuth2 flow is
+    triggered when the engine is created.
+    """
+    g = mocker.patch("superset.db_engine_specs.base.g")
+    g.user = mocker.MagicMock()
+    g.user.id = 42
+
+    database = Database(
+        id=1,
+        database_name="my_db",
+        sqlalchemy_uri="sqlite://",
+        encrypted_extra=json.dumps(oauth2_client_info),
+    )
+    database.db_engine_spec.oauth2_exception = OAuth2Error  # type: ignore
+    _get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
+    _get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
+
+    with pytest.raises(OAuth2RedirectError) as excinfo:
+        with database.get_raw_connection() as conn:
+            conn.cursor()
+    assert str(excinfo.value) == "You don't have permission to access the data."
+
+
+def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
+    """
+    Test that we can start OAuth2 from `raw_connection()` errors.
+
+    With OAuth2, some databases raise an exception when the engine is first created
+    (e.g., BigQuery), others (e.g., Snowflake) when the connection is created, and
+    others still (e.g., GSheets) when the query is executed.
+
+    This test verifies that when calling `raw_connection()` the OAuth2 flow is
+    triggered when the connection is created.
     """
     g = mocker.patch("superset.db_engine_specs.base.g")
     g.user = mocker.MagicMock()
@@ -591,6 +622,40 @@ def test_raw_connection_oauth(mocker: MockerFixture) -> None:
     assert str(excinfo.value) == "You don't have permission to access the data."
 
 
+def test_raw_connection_oauth_execute(mocker: MockerFixture) -> None:
+    """
+    Test that we can start OAuth2 from `raw_connection()` errors.
+
+    With OAuth2, some databases raise an exception when the engine is first created
+    (e.g., BigQuery), others (e.g., Snowflake) when the connection is created, and
+    others still (e.g., GSheets) when the query is executed.
+
+    This test verifies that when calling `raw_connection()` the OAuth2 flow is
+    triggered when the query is executed.
+    """
+    g = mocker.patch("superset.db_engine_specs.base.g")
+    g.user = mocker.MagicMock()
+    g.user.id = 42
+
+    database = Database(
+        id=1,
+        database_name="my_db",
+        sqlalchemy_uri="sqlite://",
+        encrypted_extra=json.dumps(oauth2_client_info),
+    )
+    database.db_engine_spec.oauth2_exception = OAuth2Error  # type: ignore
+    get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
+    get_sqla_engine().__enter__().raw_connection().cursor().execute.side_effect = (
+        OAuth2Error("OAuth2 required")
+    )
+
+    with pytest.raises(OAuth2RedirectError) as excinfo:  # noqa: PT012
+        with database.get_raw_connection() as conn:
+            cursor = conn.cursor()
+            cursor.execute("SELECT 1")
+    assert str(excinfo.value) == "You don't have permission to access the data."
+
+
 def test_get_schema_access_for_file_upload() -> None:
     """
     Test the `get_schema_access_for_file_upload` method.
@@ -638,6 +703,27 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
     )
 
 
+def test_engine_oauth2(mocker: MockerFixture) -> None:
+    """
+    Test that we handle OAuth2 when `create_engine` fails.
+    """
+    database = Database(database_name="my_db", sqlalchemy_uri="trino://")
+    mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
+    mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
+    mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
+    start_oauth2_dance = mocker.patch.object(
+        database.db_engine_spec,
+        "start_oauth2_dance",
+        side_effect=OAuth2Error("OAuth2 required"),
+    )
+
+    with pytest.raises(OAuth2Error):
+        with database.get_sqla_engine("catalog", "schema"):
+            pass
+
+    start_oauth2_dance.assert_called_with(database)
+
+
 def test_purge_oauth2_tokens(session: Session) -> None:
     """
     Test the `purge_oauth2_tokens` method.

Reply via email to