This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 305b6df6e3 feat(oauth2): add support for trino (#30081)
305b6df6e3 is described below

commit 305b6df6e3e5aaa6d3faa8fa1a2d91fcb05b7c34
Author: João Ferrão <[email protected]>
AuthorDate: Mon Nov 4 17:54:47 2024 +0100

    feat(oauth2): add support for trino (#30081)
---
 superset/db_engine_specs/base.py                 |  46 ++++-----
 superset/db_engine_specs/trino.py                |  27 +++++-
 superset/superset_typing.py                      |   4 +
 superset/utils/oauth2.py                         |   7 +-
 tests/unit_tests/db_engine_specs/test_gsheets.py |   1 +
 tests/unit_tests/db_engine_specs/test_trino.py   | 117 +++++++++++++++++------
 tests/unit_tests/models/core_test.py             |   1 +
 7 files changed, 151 insertions(+), 52 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index a086f6eff7..8cabb1e589 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -433,6 +433,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     oauth2_scope = ""
     oauth2_authorization_request_uri: str | None = None  # pylint: disable=invalid-name
     oauth2_token_request_uri: str | None = None
+    oauth2_token_request_type = "data"
 
     # Driver-specific exception that should be mapped to OAuth2RedirectError
     oauth2_exception = OAuth2RedirectError
@@ -525,6 +526,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
                 "token_request_uri",
                 cls.oauth2_token_request_uri,
             ),
+            "request_content_type": db_engine_spec_config.get(
+                "request_content_type", cls.oauth2_token_request_type
+            ),
         }
 
         return config
@@ -562,18 +566,16 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
         uri = config["token_request_uri"]
-        response = requests.post(
-            uri,
-            json={
-                "code": code,
-                "client_id": config["id"],
-                "client_secret": config["secret"],
-                "redirect_uri": config["redirect_uri"],
-                "grant_type": "authorization_code",
-            },
-            timeout=timeout,
-        )
-        return response.json()
+        req_body = {
+            "code": code,
+            "client_id": config["id"],
+            "client_secret": config["secret"],
+            "redirect_uri": config["redirect_uri"],
+            "grant_type": "authorization_code",
+        }
+        if config["request_content_type"] == "data":
+            return requests.post(uri, data=req_body, timeout=timeout).json()
+        return requests.post(uri, json=req_body, timeout=timeout).json()
 
     @classmethod
     def get_oauth2_fresh_token(
@@ -586,17 +588,15 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
         uri = config["token_request_uri"]
-        response = requests.post(
-            uri,
-            json={
-                "client_id": config["id"],
-                "client_secret": config["secret"],
-                "refresh_token": refresh_token,
-                "grant_type": "refresh_token",
-            },
-            timeout=timeout,
-        )
-        return response.json()
+        req_body = {
+            "client_id": config["id"],
+            "client_secret": config["secret"],
+            "refresh_token": refresh_token,
+            "grant_type": "refresh_token",
+        }
+        if config["request_content_type"] == "data":
+            return requests.post(uri, data=req_body, timeout=timeout).json()
+        return requests.post(uri, json=req_body, timeout=timeout).json()
 
     @classmethod
     def get_allows_alias_in_select(
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index c473528217..ad00557f65 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -27,11 +27,13 @@ from typing import Any, TYPE_CHECKING
 import numpy as np
 import pandas as pd
 import pyarrow as pa
-from flask import ctx, current_app, Flask, g
+import requests
+from flask import copy_current_request_context, ctx, current_app, Flask, g
 from sqlalchemy import text
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
 from sqlalchemy.exc import NoSuchTableError
+from trino.exceptions import HttpError
 
 from superset import db
 from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
@@ -60,11 +62,28 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class CustomTrinoAuthErrorMeta(type):
+    def __instancecheck__(cls, instance: object) -> bool:
+        logger.info("is this being called?")
+        return isinstance(
+            instance, HttpError
+        ) and "error 401: b'Invalid credentials'" in str(instance)
+
+
+class TrinoAuthError(HttpError, metaclass=CustomTrinoAuthErrorMeta):
+    pass
+
+
 class TrinoEngineSpec(PrestoBaseEngineSpec):
     engine = "trino"
     engine_name = "Trino"
     allows_alias_to_source_column = False
 
+    # OAuth 2.0
+    supports_oauth2 = True
+    oauth2_exception = TrinoAuthError
+    oauth2_token_request_type = "data"
+
     @classmethod
     def get_extra_table_metadata(
         cls,
@@ -142,6 +161,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         # Set principal_username=$effective_username
         if backend_name == "trino" and username is not None:
             connect_args["user"] = username
+            if access_token is not None:
+                http_session = requests.Session()
+                http_session.headers.update({"Authorization": f"Bearer {access_token}"})
+                connect_args["http_session"] = http_session
 
     @classmethod
     def get_url_for_impersonation(
@@ -154,6 +177,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         """
         Return a modified URL with the username set.
 
+        :param access_token: Personal access token for OAuth2
         :param url: SQLAlchemy URL object
         :param impersonate_user: Flag indicating if impersonation is enabled
         :param username: Effective username
@@ -228,6 +252,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         execute_result: dict[str, Any] = {}
         execute_event = threading.Event()
 
+        @copy_current_request_context
         def _execute(
             results: dict[str, Any],
             event: threading.Event,
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index 3a850e0acb..c3c40cd31a 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict):
     # expired access token.
     token_request_uri: str
 
+    # Not all identity providers expect json. Keycloak expects a form encoded request,
+    # which in the `requests` package context means using the `data` param, not `json`.
+    request_content_type: str
+
 
 class OAuth2TokenResponse(TypedDict, total=False):
     """
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index b889ef83c5..95db2921f6 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -23,7 +23,7 @@ from typing import Any, TYPE_CHECKING
 import backoff
 import jwt
 from flask import current_app, url_for
-from marshmallow import EXCLUDE, fields, post_load, Schema
+from marshmallow import EXCLUDE, fields, post_load, Schema, validate
 
 from superset import db
 from superset.distributed_lock import KeyValueDistributedLock
@@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema):
     )
     authorization_request_uri = fields.String(required=True)
     token_request_uri = fields.String(required=True)
+    request_content_type = fields.String(
+        required=False,
+        load_default=lambda: "json",
+        validate=validate.OneOf(["json", "data"]),
+    )
diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py
index 5d2ddb807b..4e17054db9 100644
--- a/tests/unit_tests/db_engine_specs/test_gsheets.py
+++ b/tests/unit_tests/db_engine_specs/test_gsheets.py
@@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig:
         "redirect_uri": "http://localhost:8088/api/v1/oauth2/",
         "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
         "token_request_uri": "https://oauth2.googleapis.com/token",
+        "request_content_type": "json",
     }
 
 
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 5a32cd0504..b616adfcf1 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -45,7 +45,12 @@ from superset.db_engine_specs.exceptions import (
     SupersetDBAPIProgrammingError,
 )
 from superset.sql_parse import Table
-from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
+from superset.superset_typing import (
+    OAuth2ClientConfig,
+    ResultSetColumnType,
+    SQLAColumnType,
+    SQLType,
+)
 from superset.utils import json
 from superset.utils.core import GenericDataType
 from tests.unit_tests.db_engine_specs.utils import (
@@ -421,21 +426,23 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
     def _mock_execute(*args, **kwargs):
         mock_cursor.query_id = query_id
 
-    mock_cursor.execute.side_effect = _mock_execute
-    with patch.dict(
-        "superset.config.DISALLOWED_SQL_FUNCTIONS",
-        {},
-        clear=True,
-    ):
-        TrinoEngineSpec.execute_with_cursor(
-            cursor=mock_cursor,
-            sql="SELECT 1 FROM foo",
-            query=mock_query,
-        )
+    with app.test_request_context("/some/place/"):
+        mock_cursor.execute.side_effect = _mock_execute
 
-        mock_query.set_extra_json_key.assert_called_once_with(
-            key=QUERY_CANCEL_KEY, value=query_id
-        )
+        with patch.dict(
+            "superset.config.DISALLOWED_SQL_FUNCTIONS",
+            {},
+            clear=True,
+        ):
+            TrinoEngineSpec.execute_with_cursor(
+                cursor=mock_cursor,
+                sql="SELECT 1 FROM foo",
+                query=mock_query,
+            )
+
+            mock_query.set_extra_json_key.assert_called_once_with(
+                key=QUERY_CANCEL_KEY, value=query_id
+            )
 
 
 def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
@@ -446,23 +453,25 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
     mock_cursor.query_id = None
 
     mock_query = mocker.MagicMock()
-    g.some_value = "some_value"
 
     def _mock_execute(*args, **kwargs):
         assert has_app_context()
         assert g.some_value == "some_value"
 
-    with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
-        with patch.dict(
-            "superset.config.DISALLOWED_SQL_FUNCTIONS",
-            {},
-            clear=True,
-        ):
-            TrinoEngineSpec.execute_with_cursor(
-                cursor=mock_cursor,
-                sql="SELECT 1 FROM foo",
-                query=mock_query,
-            )
+    with app.test_request_context("/some/place/"):
+        g.some_value = "some_value"
+
+        with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
+            with patch.dict(
+                "superset.config.DISALLOWED_SQL_FUNCTIONS",
+                {},
+                clear=True,
+            ):
+                TrinoEngineSpec.execute_with_cursor(
+                    cursor=mock_cursor,
+                    sql="SELECT 1 FROM foo",
+                    query=mock_query,
+                )
 
 
 def test_get_columns(mocker: MockerFixture):
@@ -784,3 +793,57 @@ def test_where_latest_partition(
         )
         == f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
     )
+
+
+@pytest.fixture
+def oauth2_config() -> OAuth2ClientConfig:
+    """
+    Config for Trino OAuth2.
+    """
+    return {
+        "id": "trino",
+        "secret": "very-secret",
+        "scope": "",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
+        "token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
+        "request_content_type": "data",
+    }
+
+
+def test_get_oauth2_token(
+    mocker: MockerFixture,
+    oauth2_config: OAuth2ClientConfig,
+) -> None:
+    """
+    Test `get_oauth2_token`.
+    """
+    from superset.db_engine_specs.trino import TrinoEngineSpec
+
+    requests = mocker.patch("superset.db_engine_specs.base.requests")
+    requests.post().json.return_value = {
+        "access_token": "access-token",
+        "expires_in": 3600,
+        "scope": "scope",
+        "token_type": "Bearer",
+        "refresh_token": "refresh-token",
+    }
+
+    assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == {
+        "access_token": "access-token",
+        "expires_in": 3600,
+        "scope": "scope",
+        "token_type": "Bearer",
+        "refresh_token": "refresh-token",
+    }
+    requests.post.assert_called_with(
+        "https://trino.auth.server.example/master/protocol/openid-connect/token",
+        data={
+            "code": "code",
+            "client_id": "trino",
+            "client_secret": "very-secret",
+            "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+            "grant_type": "authorization_code",
+        },
+        timeout=30.0,
+    )
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 452cbb6f56..1dff4784ec 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -521,6 +521,7 @@ def test_get_oauth2_config(app_context: None) -> None:
         "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
         "scope": "refresh_token session:role:USERADMIN",
         "redirect_uri": "http://example.com/api/v1/database/oauth2/",
+        "request_content_type": "json",
     }
 
 

Reply via email to