This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 305b6df6e3 feat(oauth2): add support for trino (#30081)
305b6df6e3 is described below
commit 305b6df6e3e5aaa6d3faa8fa1a2d91fcb05b7c34
Author: João Ferrão <[email protected]>
AuthorDate: Mon Nov 4 17:54:47 2024 +0100
feat(oauth2): add support for trino (#30081)
---
superset/db_engine_specs/base.py | 46 ++++-----
superset/db_engine_specs/trino.py | 27 +++++-
superset/superset_typing.py | 4 +
superset/utils/oauth2.py | 7 +-
tests/unit_tests/db_engine_specs/test_gsheets.py | 1 +
tests/unit_tests/db_engine_specs/test_trino.py | 117 +++++++++++++++++------
tests/unit_tests/models/core_test.py | 1 +
7 files changed, 151 insertions(+), 52 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index a086f6eff7..8cabb1e589 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -433,6 +433,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
oauth2_scope = ""
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
oauth2_token_request_uri: str | None = None
+ oauth2_token_request_type = "data"
# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
@@ -525,6 +526,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"token_request_uri",
cls.oauth2_token_request_uri,
),
+ "request_content_type": db_engine_spec_config.get(
+ "request_content_type", cls.oauth2_token_request_type
+ ),
}
return config
@@ -562,18 +566,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
- response = requests.post(
- uri,
- json={
- "code": code,
- "client_id": config["id"],
- "client_secret": config["secret"],
- "redirect_uri": config["redirect_uri"],
- "grant_type": "authorization_code",
- },
- timeout=timeout,
- )
- return response.json()
+ req_body = {
+ "code": code,
+ "client_id": config["id"],
+ "client_secret": config["secret"],
+ "redirect_uri": config["redirect_uri"],
+ "grant_type": "authorization_code",
+ }
+ if config["request_content_type"] == "data":
+ return requests.post(uri, data=req_body, timeout=timeout).json()
+ return requests.post(uri, json=req_body, timeout=timeout).json()
@classmethod
def get_oauth2_fresh_token(
@@ -586,17 +588,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
- response = requests.post(
- uri,
- json={
- "client_id": config["id"],
- "client_secret": config["secret"],
- "refresh_token": refresh_token,
- "grant_type": "refresh_token",
- },
- timeout=timeout,
- )
- return response.json()
+ req_body = {
+ "client_id": config["id"],
+ "client_secret": config["secret"],
+ "refresh_token": refresh_token,
+ "grant_type": "refresh_token",
+ }
+ if config["request_content_type"] == "data":
+ return requests.post(uri, data=req_body, timeout=timeout).json()
+ return requests.post(uri, json=req_body, timeout=timeout).json()
@classmethod
def get_allows_alias_in_select(
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index c473528217..ad00557f65 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -27,11 +27,13 @@ from typing import Any, TYPE_CHECKING
import numpy as np
import pandas as pd
import pyarrow as pa
-from flask import ctx, current_app, Flask, g
+import requests
+from flask import copy_current_request_context, ctx, current_app, Flask, g
from sqlalchemy import text
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
+from trino.exceptions import HttpError
from superset import db
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
@@ -60,11 +62,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class CustomTrinoAuthErrorMeta(type):
+ def __instancecheck__(cls, instance: object) -> bool:
+ logger.info("is this being called?")
+ return isinstance(
+ instance, HttpError
+ ) and "error 401: b'Invalid credentials'" in str(instance)
+
+
+class TrinoAuthError(HttpError, metaclass=CustomTrinoAuthErrorMeta):
+ pass
+
+
class TrinoEngineSpec(PrestoBaseEngineSpec):
engine = "trino"
engine_name = "Trino"
allows_alias_to_source_column = False
+ # OAuth 2.0
+ supports_oauth2 = True
+ oauth2_exception = TrinoAuthError
+ oauth2_token_request_type = "data"
+
@classmethod
def get_extra_table_metadata(
cls,
@@ -142,6 +161,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
# Set principal_username=$effective_username
if backend_name == "trino" and username is not None:
connect_args["user"] = username
+ if access_token is not None:
+ http_session = requests.Session()
+ http_session.headers.update({"Authorization": f"Bearer {access_token}"})
+ connect_args["http_session"] = http_session
@classmethod
def get_url_for_impersonation(
@@ -154,6 +177,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
"""
Return a modified URL with the username set.
+ :param access_token: Personal access token for OAuth2
:param url: SQLAlchemy URL object
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
@@ -228,6 +252,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
execute_result: dict[str, Any] = {}
execute_event = threading.Event()
+ @copy_current_request_context
def _execute(
results: dict[str, Any],
event: threading.Event,
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index 3a850e0acb..c3c40cd31a 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict):
# expired access token.
token_request_uri: str
+ # Not all identity providers expect json. Keycloak expects a form encoded request,
+ # which in the `requests` package context means using the `data` param, not `json`.
+ request_content_type: str
+
class OAuth2TokenResponse(TypedDict, total=False):
"""
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index b889ef83c5..95db2921f6 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -23,7 +23,7 @@ from typing import Any, TYPE_CHECKING
import backoff
import jwt
from flask import current_app, url_for
-from marshmallow import EXCLUDE, fields, post_load, Schema
+from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from superset import db
from superset.distributed_lock import KeyValueDistributedLock
@@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema):
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
+ request_content_type = fields.String(
+ required=False,
+ load_default=lambda: "json",
+ validate=validate.OneOf(["json", "data"]),
+ )
diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py
index 5d2ddb807b..4e17054db9 100644
--- a/tests/unit_tests/db_engine_specs/test_gsheets.py
+++ b/tests/unit_tests/db_engine_specs/test_gsheets.py
@@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig:
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
"token_request_uri": "https://oauth2.googleapis.com/token",
+ "request_content_type": "json",
}
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 5a32cd0504..b616adfcf1 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -45,7 +45,12 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIProgrammingError,
)
from superset.sql_parse import Table
-from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
+from superset.superset_typing import (
+ OAuth2ClientConfig,
+ ResultSetColumnType,
+ SQLAColumnType,
+ SQLType,
+)
from superset.utils import json
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
@@ -421,21 +426,23 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id
- mock_cursor.execute.side_effect = _mock_execute
- with patch.dict(
- "superset.config.DISALLOWED_SQL_FUNCTIONS",
- {},
- clear=True,
- ):
- TrinoEngineSpec.execute_with_cursor(
- cursor=mock_cursor,
- sql="SELECT 1 FROM foo",
- query=mock_query,
- )
+ with app.test_request_context("/some/place/"):
+ mock_cursor.execute.side_effect = _mock_execute
- mock_query.set_extra_json_key.assert_called_once_with(
- key=QUERY_CANCEL_KEY, value=query_id
- )
+ with patch.dict(
+ "superset.config.DISALLOWED_SQL_FUNCTIONS",
+ {},
+ clear=True,
+ ):
+ TrinoEngineSpec.execute_with_cursor(
+ cursor=mock_cursor,
+ sql="SELECT 1 FROM foo",
+ query=mock_query,
+ )
+
+ mock_query.set_extra_json_key.assert_called_once_with(
+ key=QUERY_CANCEL_KEY, value=query_id
+ )
def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
@@ -446,23 +453,25 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
mock_cursor.query_id = None
mock_query = mocker.MagicMock()
- g.some_value = "some_value"
def _mock_execute(*args, **kwargs):
assert has_app_context()
assert g.some_value == "some_value"
- with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
- with patch.dict(
- "superset.config.DISALLOWED_SQL_FUNCTIONS",
- {},
- clear=True,
- ):
- TrinoEngineSpec.execute_with_cursor(
- cursor=mock_cursor,
- sql="SELECT 1 FROM foo",
- query=mock_query,
- )
+ with app.test_request_context("/some/place/"):
+ g.some_value = "some_value"
+
+ with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
+ with patch.dict(
+ "superset.config.DISALLOWED_SQL_FUNCTIONS",
+ {},
+ clear=True,
+ ):
+ TrinoEngineSpec.execute_with_cursor(
+ cursor=mock_cursor,
+ sql="SELECT 1 FROM foo",
+ query=mock_query,
+ )
def test_get_columns(mocker: MockerFixture):
@@ -784,3 +793,57 @@ def test_where_latest_partition(
)
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
)
+
+
+@pytest.fixture
+def oauth2_config() -> OAuth2ClientConfig:
+ """
+ Config for Trino OAuth2.
+ """
+ return {
+ "id": "trino",
+ "secret": "very-secret",
+ "scope": "",
+ "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+ "authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
+ "token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
+ "request_content_type": "data",
+ }
+
+
+def test_get_oauth2_token(
+ mocker: MockerFixture,
+ oauth2_config: OAuth2ClientConfig,
+) -> None:
+ """
+ Test `get_oauth2_token`.
+ """
+ from superset.db_engine_specs.trino import TrinoEngineSpec
+
+ requests = mocker.patch("superset.db_engine_specs.base.requests")
+ requests.post().json.return_value = {
+ "access_token": "access-token",
+ "expires_in": 3600,
+ "scope": "scope",
+ "token_type": "Bearer",
+ "refresh_token": "refresh-token",
+ }
+
+ assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == {
+ "access_token": "access-token",
+ "expires_in": 3600,
+ "scope": "scope",
+ "token_type": "Bearer",
+ "refresh_token": "refresh-token",
+ }
+ requests.post.assert_called_with(
+ "https://trino.auth.server.example/master/protocol/openid-connect/token",
+ data={
+ "code": "code",
+ "client_id": "trino",
+ "client_secret": "very-secret",
+ "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+ "grant_type": "authorization_code",
+ },
+ timeout=30.0,
+ )
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 452cbb6f56..1dff4784ec 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -521,6 +521,7 @@ def test_get_oauth2_config(app_context: None) -> None:
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:USERADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
+ "request_content_type": "json",
}