betodealmeida commented on code in PR #38469:
URL: https://github.com/apache/superset/pull/38469#discussion_r2989380675
##########
superset/config.py:
##########
@@ -2111,6 +2111,13 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # noq
# Timeout when fetching access and refresh tokens.
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
+# Map Superset database_name values to OAUTH_PROVIDERS names.
+# For databases with a matching database_name, the upstream login token will be
+# forwarded instead of triggering a separate database OAuth2 dance.
+# Requires `save_token: True` in the corresponding OAUTH_PROVIDERS entry.
+# Example: {"trino_prod": "my_keycloak_prov", "trino_staging": "my_keycloak_prov"}
+DATABASE_OAUTH2_UPSTREAM_PROVIDERS: dict[str, str] = {}
Review Comment:
I'm worried this might be an anti-pattern: hard-coded configuration that refers to dynamic assets living in the metadata database. Things would break if someone changed the database name (which can be done from the UI, while updating this mapping would require a redeployment), and we might end up sending the token to the wrong database, accidentally or through malicious intent.
For `DATABASE_OAUTH2_CLIENTS` we do store a configuration mapping the DB engine to the client information, but we quickly realized that's not a good solution, so we implemented support for storing the client information in the database itself. We also have plans to add models for OAuth2 clients, so that a client can be configured once and reused across different databases.
I think it would be better if this information lived in the database configuration instead, since it's tied to the lifetime of the database. We can have a key `oauth2_upstream_provider` that stores the name of the provider, and when it is present we reuse the saved token, if available.
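A minimal sketch of the difference between the two approaches, assuming the provider name is stored under an `oauth2_upstream_provider` key in the database's `extra` JSON. The key name comes from the comment above; storing it in `extra` and the `FakeDatabase` class are assumptions for illustration, not Superset's actual model. Renaming two databases from the UI shows how the name-keyed mapping can route the token to the wrong database, while the per-database setting follows the row.

```python
from __future__ import annotations

import json
from dataclasses import dataclass


@dataclass
class FakeDatabase:
    """Hypothetical stand-in for Superset's Database model."""

    database_name: str
    extra: str = "{}"  # JSON blob holding per-database settings


# Approach in the PR: a static mapping keyed by database_name.
DATABASE_OAUTH2_UPSTREAM_PROVIDERS = {"trino_prod": "my_keycloak_prov"}


def provider_from_config(database: FakeDatabase) -> str | None:
    return DATABASE_OAUTH2_UPSTREAM_PROVIDERS.get(database.database_name)


def provider_from_extra(database: FakeDatabase) -> str | None:
    # Suggested approach (assumed storage location): the setting lives
    # with the database itself, so it survives renames.
    return json.loads(database.extra or "{}").get("oauth2_upstream_provider")


prod = FakeDatabase(
    "trino_prod", json.dumps({"oauth2_upstream_provider": "my_keycloak_prov"})
)
scratch = FakeDatabase("scratch")

# Both databases are renamed from the UI; the config file is unchanged.
prod.database_name = "trino_prod_old"
scratch.database_name = "trino_prod"

print(provider_from_config(prod))     # None: prod lost its provider
print(provider_from_config(scratch))  # my_keycloak_prov: wrong database
print(provider_from_extra(prod))      # my_keycloak_prov
print(provider_from_extra(scratch))   # None
```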
##########
superset/models/core.py:
##########
@@ -509,20 +509,30 @@ def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
if user and user.email:
effective_username = user.email.split("@")[0]
- oauth2_config = self.get_oauth2_config()
- access_token = (
- get_oauth2_access_token(
- oauth2_config,
- self.id,
- g.user.id,
- self.db_engine_spec,
- )
- if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
- else None
- )
masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
+ # Check if this database has an upstream login provider configured.
+ # If so, use the saved login token instead of a separate database OAuth2 dance.
+ upstream_providers = app.config.get("DATABASE_OAUTH2_UPSTREAM_PROVIDERS", {})
+ upstream_provider = upstream_providers.get(self.database_name)
+ if upstream_provider and hasattr(g, "user") and hasattr(g.user, "id"):
+ from superset.utils.oauth2 import get_upstream_provider_token
+
+ access_token = get_upstream_provider_token(upstream_provider, g.user.id)
Review Comment:
I think we could push this logic to `get_oauth2_access_token`, so we keep
`superset/models/core.py` clean.
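A minimal sketch of that refactor, assuming the upstream lookup moves behind a single helper so the model makes only one call. `resolve_access_token` and its callable parameters are hypothetical stand-ins for an extended `get_oauth2_access_token` (called in the removed lines as `get_oauth2_access_token(oauth2_config, self.id, g.user.id, self.db_engine_spec)`); falling back to the per-database token when no login token is saved is an assumption based on the "if available" wording in the first comment.

```python
from __future__ import annotations

from typing import Callable, Optional

# Hypothetical lookup types: in Superset these would be the PR's
# get_upstream_provider_token and the existing per-database OAuth2 logic.
UpstreamLookup = Callable[[str, int], Optional[str]]
DatabaseLookup = Callable[[int], Optional[str]]


def resolve_access_token(
    upstream_provider: Optional[str],
    user_id: int,
    get_upstream_token: UpstreamLookup,
    get_database_token: DatabaseLookup,
) -> Optional[str]:
    """Prefer the saved login token; otherwise fall back to database OAuth2."""
    if upstream_provider:
        token = get_upstream_token(upstream_provider, user_id)
        if token is not None:
            return token
    return get_database_token(user_id)


# Demo with in-memory token stores.
saved_login_tokens = {("my_keycloak_prov", 1): "login-token"}
database_tokens = {1: "db-token-1", 2: "db-token-2"}


def upstream(provider: str, user_id: int) -> Optional[str]:
    return saved_login_tokens.get((provider, user_id))


print(resolve_access_token("my_keycloak_prov", 1, upstream, database_tokens.get))  # login-token
print(resolve_access_token("my_keycloak_prov", 2, upstream, database_tokens.get))  # db-token-2
print(resolve_access_token(None, 1, upstream, database_tokens.get))  # db-token-1
```

With this shape, `_get_sqla_engine` keeps a single token call and the upstream-provider branch lives entirely in `superset/utils/oauth2.py`.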
##########
superset/utils/oauth2.py:
##########
@@ -276,3 +280,100 @@ def check_for_oauth2(database: Database) -> Iterator[None]:
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(ex):
database.db_engine_spec.start_oauth2_dance(database)
raise
+
+
+def save_user_provider_token(
+ user_id: int,
+ provider: str,
+ token_response: dict[str, Any],
+) -> None:
+ """
+ Upsert an UpstreamOAuthToken row for the given user + provider.
+ """
+ from superset.models.core import UpstreamOAuthToken
+
+ token: UpstreamOAuthToken | None = (
+ db.session.query(UpstreamOAuthToken)
+ .filter_by(user_id=user_id, provider=provider)
+ .one_or_none()
+ )
+ if token is None:
+ token = UpstreamOAuthToken(user_id=user_id, provider=provider)
+
+ token.access_token = token_response.get("access_token")
+ expires_in = token_response.get("expires_in")
+ token.access_token_expiration = (
+ datetime.now() + timedelta(seconds=expires_in) if expires_in else None
+ )
+ token.refresh_token = token_response.get("refresh_token")
+ db.session.add(token)
+ db.session.commit()
+
+
+def get_upstream_provider_token(provider: str, user_id: int) -> str | None:
+ """
+ Retrieve a valid access token for the given provider and user.
+
+ If the token is expired and a refresh token exists, attempt to refresh it.
+ Returns None if no valid token is available.
+ """
+ from superset.models.core import UpstreamOAuthToken
+
+ token: UpstreamOAuthToken | None = (
+ db.session.query(UpstreamOAuthToken)
+ .filter_by(user_id=user_id, provider=provider)
+ .one_or_none()
+ )
+ if token is None:
+ return None
+
+ now = datetime.now()
+ if token.access_token_expiration is None or token.access_token_expiration > now:
+ return token.access_token
+
+ # Token is expired
+ if token.refresh_token:
+ return _refresh_upstream_provider_token(token, provider)
+
Review Comment:
Yeah, this is an issue. Let's use `DistributedLock` like we do in
`refresh_oauth2_token`.
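A minimal sketch of the double-checked pattern a lock enables here, assuming the concern is several requests refreshing the same expired token at once. It uses `threading.Lock` purely as a process-local stand-in and does not reproduce Superset's `DistributedLock` API, which must also coordinate across workers. `Token`, `get_valid_token`, and the `load`/`refresh` callables are hypothetical.

```python
from __future__ import annotations

import threading
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Callable, Optional


@dataclass
class Token:
    """Hypothetical stand-in for an UpstreamOAuthToken row."""

    access_token: Optional[str]
    access_token_expiration: Optional[datetime]
    refresh_token: Optional[str]


# Process-local stand-in for DistributedLock: one lock per (user_id, provider).
# A real deployment needs a lock shared by all workers.
_locks: dict[tuple[int, str], threading.Lock] = {}


def _is_valid(token: Token) -> bool:
    # Same rule as the diff: no expiration, or expiration in the future.
    expiration = token.access_token_expiration
    return expiration is None or expiration > datetime.now()


def get_valid_token(
    key: tuple[int, str],
    load: Callable[[], Token],
    refresh: Callable[[Token], Token],
) -> Optional[str]:
    token = load()
    if _is_valid(token):
        return token.access_token
    if not token.refresh_token:
        return None
    lock = _locks.setdefault(key, threading.Lock())
    with lock:
        # Re-read under the lock: another thread may have refreshed already,
        # so the old refresh token is not spent twice.
        token = load()
        if _is_valid(token):
            return token.access_token
        if not token.refresh_token:
            return None
        return refresh(token).access_token


# Demo: eight threads hit the same expired token concurrently.
store = {"token": Token("old", datetime.now() - timedelta(minutes=1), "r1")}
refresh_calls: list[str] = []


def load() -> Token:
    return store["token"]


def refresh(token: Token) -> Token:
    refresh_calls.append(token.refresh_token)
    store["token"] = Token("new", datetime.now() + timedelta(hours=1), "r2")
    return store["token"]


threads = [
    threading.Thread(target=get_valid_token, args=((1, "my_keycloak_prov"), load, refresh))
    for _ in range(8)
]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

print(refresh_calls)  # ['r1']: only one thread refreshed
```

Re-reading the token after acquiring the lock is what stops the waiting requests from refreshing again; providers that rotate refresh tokens may reject a second refresh made with the old one.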
--
This is an automated message from the Apache Git Service.