This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 5660f8e554 feat: OAuth2 client initial work (#29109)
5660f8e554 is described below

commit 5660f8e5542b78e098c42306633c182d9a631d63
Author: Beto Dealmeida <[email protected]>
AuthorDate: Sun Jun 9 22:11:18 2024 -0400

    feat: OAuth2 client initial work (#29109)
---
 superset/db_engine_specs/base.py     | 21 ++++++---
 superset/models/core.py              | 54 ++++++++++++++---------
 superset/utils/oauth2.py             | 14 +++++-
 tests/unit_tests/conftest.py         |  3 ++
 tests/unit_tests/models/core_test.py | 83 ++++++++++++++++++++++++++++++++++++
 tests/unit_tests/sql_lab_test.py     | 62 +++++++++++++++++++++++++++
 6 files changed, 210 insertions(+), 27 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 548fb390d8..cd37e4e602 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1610,9 +1610,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
         return [
-            literal_column(query_as)
-            if (query_as := c.get("query_as"))
-            else column(c["column_name"])
+            (
+                literal_column(query_as)
+                if (query_as := c.get("query_as"))
+                else column(c["column_name"])
+            )
             for c in cols
         ]
 
@@ -1828,13 +1830,18 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             cursor.arraysize = cls.arraysize
         try:
             cursor.execute(query)
-        except cls.oauth2_exception as ex:
-            if database.is_oauth2_enabled() and g and g.user:
-                cls.start_oauth2_dance(database)
-            raise cls.get_dbapi_mapped_exception(ex) from ex
         except Exception as ex:
+            if database.is_oauth2_enabled() and cls.needs_oauth2(ex):
+                cls.start_oauth2_dance(database)
             raise cls.get_dbapi_mapped_exception(ex) from ex
 
+    @classmethod
+    def needs_oauth2(cls, ex: Exception) -> bool:
+        """
+        Check if the exception is one that indicates OAuth2 is needed.
+        """
+        return g and hasattr(g, "user") and isinstance(ex, cls.oauth2_exception)
+
     @classmethod
     def make_label_compatible(cls, label: str) -> str | quoted_name:
         """
diff --git a/superset/models/core.py b/superset/models/core.py
index e6d97a197b..c8c875e435 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -29,7 +29,7 @@ from contextlib import closing, contextmanager, nullcontext, suppress
 from copy import deepcopy
 from datetime import datetime
 from functools import lru_cache
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Callable, cast, TYPE_CHECKING
 
 import numpy
 import pandas as pd
@@ -78,7 +78,7 @@ from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
 from superset.utils import cache as cache_util, core as utils, json
 from superset.utils.backports import StrEnum
 from superset.utils.core import DatasourceName, get_username
-from superset.utils.oauth2 import get_oauth2_access_token
+from superset.utils.oauth2 import get_oauth2_access_token, OAuth2ClientConfigSchema
 
 config = app.config
 custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
@@ -554,17 +554,23 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
             nullpool=nullpool,
             source=source,
         ) as engine:
-            with closing(engine.raw_connection()) as conn:
-                # pre-session queries are used to set the selected schema and, in the
-                # future, the selected catalog
-                for prequery in self.db_engine_spec.get_prequeries(
-                    catalog=catalog,
-                    schema=schema,
-                ):
-                    cursor = conn.cursor()
-                    cursor.execute(prequery)
+            try:
+                with closing(engine.raw_connection()) as conn:
+                    # pre-session queries are used to set the selected schema and, in the
+                    # future, the selected catalog
+                    for prequery in self.db_engine_spec.get_prequeries(
+                        catalog=catalog,
+                        schema=schema,
+                    ):
+                        cursor = conn.cursor()
+                        cursor.execute(prequery)
 
-                yield conn
+                    yield conn
+
+            except Exception as ex:
+                if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
+                    self.db_engine_spec.start_oauth2_dance(self)
+                raise ex
 
     def get_default_catalog(self) -> str | None:
         """
@@ -1063,20 +1069,30 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         """
         Is OAuth2 enabled in the database for authentication?
 
-        Currently this looks for a global config at the DB engine spec level, but in the
-        future we want to be allow admins to create custom OAuth2 clients from the
-        Superset UI, and assign them to specific databases.
+        Currently this checks for configuration stored in the database `extra`, and then
+        for a global config at the DB engine spec level. In the future we want to allow
+        admins to create custom OAuth2 clients from the Superset UI, and assign them to
+        specific databases.
         """
-        return self.db_engine_spec.is_oauth2_enabled()
+        encrypted_extra = json.loads(self.encrypted_extra or "{}")
+        oauth2_client_info = encrypted_extra.get("oauth2_client_info", {})
+        return bool(oauth2_client_info) or self.db_engine_spec.is_oauth2_enabled()
 
     def get_oauth2_config(self) -> OAuth2ClientConfig | None:
         """
         Return OAuth2 client configuration.
 
-        This includes client ID, client secret, scope, redirect URI, endpointsm etc.
-        Currently this reads the global DB engine spec config, but in the future it
-        should first check if there's a custom client assigned to the database.
+        Currently this checks for configuration stored in the database `extra`, and then
+        for a global config at the DB engine spec level. In the future we want to allow
+        admins to create custom OAuth2 clients from the Superset UI, and assign them to
+        specific databases.
         """
+        encrypted_extra = json.loads(self.encrypted_extra or "{}")
+        if oauth2_client_info := encrypted_extra.get("oauth2_client_info"):
+            schema = OAuth2ClientConfigSchema()
+            client_config = schema.load(oauth2_client_info)
+            return cast(OAuth2ClientConfig, client_config)
+
         return self.db_engine_spec.get_oauth2_config()
 
 
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 9cc58a0b7f..bc4805fd81 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -22,7 +22,7 @@ from typing import Any, TYPE_CHECKING
 
 import backoff
 import jwt
-from flask import current_app
+from flask import current_app, url_for
 from marshmallow import EXCLUDE, fields, post_load, Schema
 
 from superset import db
@@ -180,3 +180,15 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State:
     state = oauth2_state_schema.load(payload)
 
     return state
+
+
+class OAuth2ClientConfigSchema(Schema):
+    id = fields.String(required=True)
+    secret = fields.String(required=True)
+    scope = fields.String(required=True)
+    redirect_uri = fields.String(
+        required=False,
+        load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True),
+    )
+    authorization_request_uri = fields.String(required=True)
+    token_request_uri = fields.String(required=True)
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index 3905a15b32..2b8f39d6dd 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -90,6 +90,9 @@ def app(request: SubRequest) -> Iterator[SupersetApp]:
     app.config["RATELIMIT_ENABLED"] = False
     app.config["CACHE_CONFIG"] = {}
     app.config["DATA_CACHE_CONFIG"] = {}
+    app.config["SERVER_NAME"] = "example.com"
+    app.config["APPLICATION_ROOT"] = "/"
+    app.config["PREFERRED_URL_SCHEME="] = "http"
 
     # loop over extra configs passed in by tests
     # and update the app config
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 2004ff482f..c4d642baf5 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -16,6 +16,7 @@
 # under the License.
 
 # pylint: disable=import-outside-toplevel
+
 from datetime import datetime
 
 import pytest
@@ -24,11 +25,23 @@ from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import make_url
 
 from superset.connectors.sqla.models import SqlaTable, TableColumn
+from superset.exceptions import OAuth2Error, OAuth2RedirectError
 from superset.models.core import Database
 from superset.sql_parse import Table
 from superset.utils import json
 from tests.unit_tests.conftest import with_feature_flags
 
+# sample config for OAuth2 tests
+oauth2_client_info = {
+    "oauth2_client_info": {
+        "id": "my_client_id",
+        "secret": "my_client_secret",
+        "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
+        "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
+        "scope": "refresh_token session:role:SYSADMIN",
+    }
+}
+
 
 def test_get_metrics(mocker: MockFixture) -> None:
     """
@@ -378,3 +391,73 @@ def test_get_sqla_engine_user_impersonation_email(mocker: MockFixture) -> None:
         make_url("trino:///"),
         connect_args={"user": "alice.doe", "source": "Apache Superset"},
     )
+
+
+def test_is_oauth2_enabled() -> None:
+    """
+    Test the `is_oauth2_enabled` method.
+    """
+    database = Database(
+        database_name="db",
+        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+    )
+
+    assert not database.is_oauth2_enabled()
+
+    database.encrypted_extra = json.dumps(oauth2_client_info)
+    assert database.is_oauth2_enabled()
+
+
+def test_get_oauth2_config(app_context: None) -> None:
+    """
+    Test the `get_oauth2_config` method.
+    """
+    database = Database(
+        database_name="db",
+        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+    )
+
+    assert database.get_oauth2_config() is None
+
+    database.encrypted_extra = json.dumps(oauth2_client_info)
+    assert database.get_oauth2_config() == {
+        "id": "my_client_id",
+        "secret": "my_client_secret",
+        "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
+        "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
+        "scope": "refresh_token session:role:SYSADMIN",
+        "redirect_uri": "http://example.com/api/v1/database/oauth2/",
+    }
+
+
+def test_raw_connection_oauth(mocker: MockFixture) -> None:
+    """
+    Test that we can start OAuth2 from `raw_connection()` errors.
+
+    Some databases that use OAuth2 need to trigger the flow when the connection is
+    created, rather than when the query runs. This happens when the SQLAlchemy engine
+    URI cannot be built without the user personal token.
+
+    This test verifies that the exception is captured and raised correctly so that the
+    frontend can trigger the OAuth2 dance.
+    """
+    g = mocker.patch("superset.db_engine_specs.base.g")
+    g.user = mocker.MagicMock()
+    g.user.id = 42
+
+    database = Database(
+        id=1,
+        database_name="my_db",
+        sqlalchemy_uri="sqlite://",
+        encrypted_extra=json.dumps(oauth2_client_info),
+    )
+    database.db_engine_spec.oauth2_exception = OAuth2Error  # type: ignore
+    get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
+    get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
+        "OAuth2 required"
+    )
+
+    with pytest.raises(OAuth2RedirectError) as excinfo:
+        with database.get_raw_connection() as conn:
+            conn.cursor()
+    assert str(excinfo.value) == "You don't have permission to access the data."
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
index 3b2e7690e1..6ac055df13 100644
--- a/tests/unit_tests/sql_lab_test.py
+++ b/tests/unit_tests/sql_lab_test.py
@@ -16,12 +16,22 @@
 # under the License.
 # pylint: disable=import-outside-toplevel, invalid-name, unused-argument, too-many-locals
 
+import json
+from uuid import UUID
+
 import sqlparse
+from freezegun import freeze_time
 from pytest_mock import MockerFixture
 from sqlalchemy.orm.session import Session
 
 from superset import db
+from superset.common.db_query_status import QueryStatus
+from superset.errors import ErrorLevel, SupersetErrorType
+from superset.exceptions import OAuth2Error
+from superset.models.core import Database
+from superset.sql_lab import get_sql_results
 from superset.utils.core import override_user
+from tests.unit_tests.models.core_test import oauth2_client_info
 
 
 def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
@@ -218,3 +228,55 @@ def test_sql_lab_insert_rls_as_subquery(
         query.executed_sql
         == "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6"
     )
+
+
+@freeze_time("2021-04-01T00:00:00Z")
+def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
+    """
+    Test that `get_sql_results` works with OAuth2.
+    """
+    app_context = app.test_request_context()
+    app_context.push()
+
+    mocker.patch(
+        "superset.db_engine_specs.base.uuid4",
+        return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"),
+    )
+
+    g = mocker.patch("superset.db_engine_specs.base.g")
+    g.user = mocker.MagicMock()
+    g.user.id = 42
+
+    database = Database(
+        id=1,
+        database_name="my_db",
+        sqlalchemy_uri="sqlite://",
+        encrypted_extra=json.dumps(oauth2_client_info),
+    )
+    database.db_engine_spec.oauth2_exception = OAuth2Error  # type: ignore
+    get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
+    get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
+        "OAuth2 required"
+    )
+
+    query = mocker.MagicMock()
+    query.database = database
+    mocker.patch("superset.sql_lab.get_query", return_value=query)
+
+    payload = get_sql_results(query_id=1, rendered_query="SELECT 1")
+    assert payload == {
+        "status": QueryStatus.FAILED,
+        "error": "You don't have permission to access the data.",
+        "errors": [
+            {
+                "message": "You don't have permission to access the data.",
+                "error_type": SupersetErrorType.OAUTH2_REDIRECT,
+                "level": ErrorLevel.WARNING,
+                "extra": {
+                    "url": 
"https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3ASYSADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vZXhhbXBsZS5jb20vYXBpL3YxL2RhdGFiYXNlL29hdXRoMi8iLCJ0YWJfaWQiOiJmYjExZjUyOC02ZWJhLTRhOGEtODM3ZS02YjBkMzllZTkxODcifQ%252Ec_m_35xwwSrLgCXwV4aKO192
 [...]
+                    "tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
+                    "redirect_uri": "http://example.com/api/v1/database/oauth2/",
+                },
+            }
+        ],
+    }

Reply via email to