ashb commented on code in PR #46981: URL: https://github.com/apache/airflow/pull/46981#discussion_r1969495572
########## airflow/security/tokens.py: ########## @@ -0,0 +1,504 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import os +import time +from collections.abc import Sequence +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + +import attrs +import httpx +import jwt +import structlog +from asgiref.sync import async_to_sync +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.serialization import load_pem_private_key + +from airflow.utils import timezone + +if TYPE_CHECKING: + from jwt.algorithms import AllowedKeys, AllowedPrivateKeys + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "InvalidClaimError", + "JWKS", + "JWTGenerator", + "JWTValidator", + "generate_private_key", + "get_signing_key", + "key_to_pem", + "key_to_jwk_dict", +] + + +class InvalidClaimError(ValueError): + """Raised when a claim in the JWT is invalid.""" + + def __init__(self, claim: str): + super().__init__(f"Invalid claim: {claim}") + + +def key_to_jwk_dict(key: AllowedKeys, kid: str | None = None): + """Convert a public or private key into a valid JWKS dict.""" + 
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey + from jwt.algorithms import OKPAlgorithm, RSAAlgorithm + + if isinstance(key, (RSAPrivateKey, Ed25519PrivateKey)): + key = key.public_key() + + if isinstance(key, RSAPublicKey): + jwk_dict = RSAAlgorithm(RSAAlgorithm.SHA256).to_jwk(key, as_dict=True) + + elif isinstance(key, Ed25519PublicKey): + jwk_dict = OKPAlgorithm().to_jwk(key, as_dict=True) + else: + raise ValueError(f"Unknown key object {type(key)}") + + if not kid: + kid = thumbprint(jwk_dict) + + jwk_dict["kid"] = kid + + return jwk_dict + + +def _guess_best_algorithm(key: AllowedPrivateKeys): + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + + if isinstance(key, RSAPrivateKey): + return "RS512" + elif isinstance(key, Ed25519PrivateKey): + return "EdDSA" + else: + raise ValueError(f"Unknown key object {type(key)}") + + +@attrs.define(repr=False) +class JWKS: + """A class to fetch and sync a set of JSON Web Keys.""" + + url: str + fetched_at: float = 0 + last_fetch_attempt_at: float = 0 + + client: httpx.AsyncClient = attrs.field(factory=httpx.AsyncClient) + + _jwks: jwt.PyJWKSet | None = None + refresh_jwks: bool = True + refresh_interval_secs: int = 3600 + refresh_retry_interval_secs: int = 10 + Review Comment: We don't _want_ to close the client, we want to keep it around for background refreshes of the JWKS document. If we close it we can't do that anymore. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
