sunank200 commented on code in PR #46981: URL: https://github.com/apache/airflow/pull/46981#discussion_r1967195103
########## airflow/security/tokens.py: ########## @@ -0,0 +1,504 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import os +import time +from collections.abc import Sequence +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + +import attrs +import httpx +import jwt +import structlog +from asgiref.sync import async_to_sync +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.serialization import load_pem_private_key + +from airflow.utils import timezone + +if TYPE_CHECKING: + from jwt.algorithms import AllowedKeys, AllowedPrivateKeys + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "InvalidClaimError", + "JWKS", + "JWTGenerator", + "JWTValidator", + "generate_private_key", + "get_signing_key", + "key_to_pem", + "key_to_jwk_dict", +] + + +class InvalidClaimError(ValueError): + """Raised when a claim in the JWT is invalid.""" + + def __init__(self, claim: str): + super().__init__(f"Invalid claim: {claim}") + + +def key_to_jwk_dict(key: AllowedKeys, kid: str | None = None): + """Convert a public or private key into a valid JWKS dict.""" + 
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey + from jwt.algorithms import OKPAlgorithm, RSAAlgorithm + + if isinstance(key, (RSAPrivateKey, Ed25519PrivateKey)): + key = key.public_key() + + if isinstance(key, RSAPublicKey): + jwk_dict = RSAAlgorithm(RSAAlgorithm.SHA256).to_jwk(key, as_dict=True) + + elif isinstance(key, Ed25519PublicKey): + jwk_dict = OKPAlgorithm().to_jwk(key, as_dict=True) + else: + raise ValueError(f"Unknown key object {type(key)}") + + if not kid: + kid = thumbprint(jwk_dict) + + jwk_dict["kid"] = kid + + return jwk_dict + + +def _guess_best_algorithm(key: AllowedPrivateKeys): + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + + if isinstance(key, RSAPrivateKey): + return "RS512" + elif isinstance(key, Ed25519PrivateKey): + return "EdDSA" + else: + raise ValueError(f"Unknown key object {type(key)}") + + +@attrs.define(repr=False) +class JWKS: + """A class to fetch and sync a set of JSON Web Keys.""" + + url: str + fetched_at: float = 0 + last_fetch_attempt_at: float = 0 + + client: httpx.AsyncClient = attrs.field(factory=httpx.AsyncClient) + + _jwks: jwt.PyJWKSet | None = None + refresh_jwks: bool = True + refresh_interval_secs: int = 3600 + refresh_retry_interval_secs: int = 10 + Review Comment: ```suggestion async def __aenter__(self): return self async def __aexit__(self, exc_type, exc, tb): await self.client.aclose() async def close(self) -> None: """Close the HTTP client explicitly.""" await self.client.aclose() ``` Isn't something like this needed? As written, there is no explicit cleanup (such as closing the HTTP client) once the JWKS object is no longer needed. 
########## airflow/auth/managers/base_auth_manager.py: ########## @@ -457,14 +457,28 @@ def register_views(self) -> None: """Register views specific to the auth manager.""" @staticmethod - def _get_token_signer(): + def _get_token_signer() -> JWTGenerator: """ Return the signer used to sign JWT token. :meta private: """ - return JWTSigner( - secret_key=get_signing_key("api", "auth_jwt_secret"), - expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"), + return JWTGenerator( + secret_key=get_signing_key("api_auth", "jwt_secret"), + valid_for=conf.getint("api", "auth_jwt_expiration_time"), + audience="front-apis", + ) + + @staticmethod + def _get_token_validator() -> JWTValidator: + """ + Return the signer used to sign JWT token. + + :meta private: + """ + return JWTValidator( + # issuer=conf.get("api_auth", "jwt_iussuer"), Review Comment: ```suggestion issuer=conf.get("api_auth", "jwt_iussuer"), ``` ########## airflow/security/tokens.py: ########## @@ -0,0 +1,504 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +import os +import time +from collections.abc import Sequence +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + +import attrs +import httpx +import jwt +import structlog +from asgiref.sync import async_to_sync +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.serialization import load_pem_private_key + +from airflow.utils import timezone + +if TYPE_CHECKING: + from jwt.algorithms import AllowedKeys, AllowedPrivateKeys + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "InvalidClaimError", + "JWKS", + "JWTGenerator", + "JWTValidator", + "generate_private_key", + "get_signing_key", + "key_to_pem", + "key_to_jwk_dict", +] + + +class InvalidClaimError(ValueError): + """Raised when a claim in the JWT is invalid.""" + + def __init__(self, claim: str): + super().__init__(f"Invalid claim: {claim}") + + +def key_to_jwk_dict(key: AllowedKeys, kid: str | None = None): + """Convert a public or private key into a valid JWKS dict.""" + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey + from jwt.algorithms import OKPAlgorithm, RSAAlgorithm + + if isinstance(key, (RSAPrivateKey, Ed25519PrivateKey)): + key = key.public_key() + + if isinstance(key, RSAPublicKey): + jwk_dict = RSAAlgorithm(RSAAlgorithm.SHA256).to_jwk(key, as_dict=True) + + elif isinstance(key, Ed25519PublicKey): + jwk_dict = OKPAlgorithm().to_jwk(key, as_dict=True) + else: + raise ValueError(f"Unknown key object {type(key)}") + + if not kid: + kid = thumbprint(jwk_dict) + + jwk_dict["kid"] = kid + + return jwk_dict + + +def _guess_best_algorithm(key: AllowedPrivateKeys): + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey + from 
cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + + if isinstance(key, RSAPrivateKey): + return "RS512" + elif isinstance(key, Ed25519PrivateKey): + return "EdDSA" + else: + raise ValueError(f"Unknown key object {type(key)}") + + [email protected](repr=False) +class JWKS: + """A class to fetch and sync a set of JSON Web Keys.""" + + url: str + fetched_at: float = 0 + last_fetch_attempt_at: float = 0 + + client: httpx.AsyncClient = attrs.field(factory=httpx.AsyncClient) + + _jwks: jwt.PyJWKSet | None = None + refresh_jwks: bool = True + refresh_interval_secs: int = 3600 + refresh_retry_interval_secs: int = 10 + + def __repr__(self) -> str: + return f"JWKS(url={self.url}, fetched_at={self.fetched_at})" + + @classmethod + def from_private_key(cls, **keys: AllowedPrivateKeys): + obj = cls(url=os.devnull) + + obj._jwks = jwt.PyJWKSet([key_to_jwk_dict(key, kid) for kid, key in keys.items()]) + return obj + + async def fetch_jwks(self) -> None: + if not self._should_fetch_jwks(): + return + if self.url.startswith("http"): + data = await self._fetch_remote_jwks() + else: + data = self._fetch_local_jwks() + + if not data: + return + + self._jwks = jwt.PyJWKSet.from_dict(data) + log.debug("Fetched JWKS", url=self.url, keys=len(self._jwks.keys)) + + async def _fetch_remote_jwks(self) -> dict[str, Any] | None: + try: + log.debug( + "Fetching JWKS", + url=self.url, + last_fetched_secs_ago=int(time.monotonic() - self.fetched_at) if self.fetched_at else None, + ) + if TYPE_CHECKING: + assert self.url + self.last_fetch_attempt_at = int(time.monotonic()) + response = await self.client.get(self.url) + response.raise_for_status() + self.fetched_at = int(time.monotonic()) + await response.aread() + await response.aclose() + return response.json() + except Exception: + log.exception("Failed to fetch remote JWKS", url=self.url) + return None + + def _fetch_local_jwks(self) -> dict[str, Any] | None: + try: + with open(self.url) as jwks_file: + content = 
json.load(jwks_file) + self.fetched_at = int(time.monotonic()) + return content + except Exception: + log.exception("Failed to read local JWKS", url=self.url) + return None + + def _should_fetch_jwks(self) -> bool: + """ + Check if we need to fetch the JWKS based on the last fetch time and the refresh interval. + + If the JWKS URL is local, we only fetch it once. For remote JWKS URLs we fetch it based + on the refresh interval if refreshing has been enabled with a minimum interval between + attempts. The fetcher functions set the fetched_at timestamp to the current monotonic time + when the JWKS is fetched. + """ + if not self.url.startswith("http"): + # Fetch local JWKS only if not already loaded + # This could be improved in future by looking at mtime of file. + return not self._jwks + # For remote fetches we check if the JWKS is not loaded (fetched_at = 0) or if the last fetch was more than + # refresh_interval_secs ago and the last fetch attempt was more than refresh_retry_interval_secs ago + now = time.monotonic() + return self.refresh_jwks and ( + not self._jwks + or ( + self.fetched_at == 0 + or ( + now - self.fetched_at > self.refresh_interval_secs + and now - self.last_fetch_attempt_at > self.refresh_retry_interval_secs + ) + ) + ) Review Comment: ```suggestion refresh_needed = not self._jwks or self.fetched_at == 0 time_for_refresh = now - self.fetched_at > self.refresh_interval_secs time_for_retry = now - self.last_fetch_attempt_at > self.refresh_retry_interval_secs return self.refresh_jwks and (refresh_needed or (time_for_refresh and time_for_retry)) ``` ########## airflow/auth/managers/base_auth_manager.py: ########## @@ -457,14 +457,28 @@ def register_views(self) -> None: """Register views specific to the auth manager.""" @staticmethod - def _get_token_signer(): + def _get_token_signer() -> JWTGenerator: """ Return the signer used to sign JWT token. 
:meta private: """ - return JWTSigner( - secret_key=get_signing_key("api", "auth_jwt_secret"), - expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"), + return JWTGenerator( + secret_key=get_signing_key("api_auth", "jwt_secret"), + valid_for=conf.getint("api", "auth_jwt_expiration_time"), + audience="front-apis", + ) + + @staticmethod + def _get_token_validator() -> JWTValidator: + """ + Return the signer used to sign JWT token. + + :meta private: + """ + return JWTValidator( + # issuer=conf.get("api_auth", "jwt_iussuer"), Review Comment: Or maybe the issuer could be None? ########## airflow/security/tokens.py: ########## @@ -0,0 +1,504 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +import os +import time +from collections.abc import Sequence +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + +import attrs +import httpx +import jwt +import structlog +from asgiref.sync import async_to_sync +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.serialization import load_pem_private_key + +from airflow.utils import timezone + +if TYPE_CHECKING: + from jwt.algorithms import AllowedKeys, AllowedPrivateKeys + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "InvalidClaimError", + "JWKS", + "JWTGenerator", + "JWTValidator", + "generate_private_key", + "get_signing_key", + "key_to_pem", + "key_to_jwk_dict", +] + + +class InvalidClaimError(ValueError): + """Raised when a claim in the JWT is invalid.""" + + def __init__(self, claim: str): + super().__init__(f"Invalid claim: {claim}") + + +def key_to_jwk_dict(key: AllowedKeys, kid: str | None = None): + """Convert a public or private key into a valid JWKS dict.""" + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey + from jwt.algorithms import OKPAlgorithm, RSAAlgorithm + + if isinstance(key, (RSAPrivateKey, Ed25519PrivateKey)): + key = key.public_key() + + if isinstance(key, RSAPublicKey): + jwk_dict = RSAAlgorithm(RSAAlgorithm.SHA256).to_jwk(key, as_dict=True) + + elif isinstance(key, Ed25519PublicKey): + jwk_dict = OKPAlgorithm().to_jwk(key, as_dict=True) + else: + raise ValueError(f"Unknown key object {type(key)}") + + if not kid: + kid = thumbprint(jwk_dict) + + jwk_dict["kid"] = kid + + return jwk_dict + + +def _guess_best_algorithm(key: AllowedPrivateKeys): + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey + from 
cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + + if isinstance(key, RSAPrivateKey): + return "RS512" + elif isinstance(key, Ed25519PrivateKey): + return "EdDSA" + else: + raise ValueError(f"Unknown key object {type(key)}") + + [email protected](repr=False) +class JWKS: + """A class to fetch and sync a set of JSON Web Keys.""" + + url: str + fetched_at: float = 0 + last_fetch_attempt_at: float = 0 + + client: httpx.AsyncClient = attrs.field(factory=httpx.AsyncClient) + + _jwks: jwt.PyJWKSet | None = None + refresh_jwks: bool = True + refresh_interval_secs: int = 3600 + refresh_retry_interval_secs: int = 10 + + def __repr__(self) -> str: + return f"JWKS(url={self.url}, fetched_at={self.fetched_at})" + + @classmethod + def from_private_key(cls, **keys: AllowedPrivateKeys): + obj = cls(url=os.devnull) + + obj._jwks = jwt.PyJWKSet([key_to_jwk_dict(key, kid) for kid, key in keys.items()]) + return obj + + async def fetch_jwks(self) -> None: + if not self._should_fetch_jwks(): + return + if self.url.startswith("http"): + data = await self._fetch_remote_jwks() + else: + data = self._fetch_local_jwks() + + if not data: + return + + self._jwks = jwt.PyJWKSet.from_dict(data) + log.debug("Fetched JWKS", url=self.url, keys=len(self._jwks.keys)) + + async def _fetch_remote_jwks(self) -> dict[str, Any] | None: + try: + log.debug( + "Fetching JWKS", + url=self.url, + last_fetched_secs_ago=int(time.monotonic() - self.fetched_at) if self.fetched_at else None, + ) + if TYPE_CHECKING: + assert self.url + self.last_fetch_attempt_at = int(time.monotonic()) + response = await self.client.get(self.url) + response.raise_for_status() + self.fetched_at = int(time.monotonic()) + await response.aread() + await response.aclose() + return response.json() + except Exception: + log.exception("Failed to fetch remote JWKS", url=self.url) + return None + + def _fetch_local_jwks(self) -> dict[str, Any] | None: + try: + with open(self.url) as jwks_file: + content = 
json.load(jwks_file) + self.fetched_at = int(time.monotonic()) + return content + except Exception: + log.exception("Failed to read local JWKS", url=self.url) + return None + + def _should_fetch_jwks(self) -> bool: + """ + Check if we need to fetch the JWKS based on the last fetch time and the refresh interval. + + If the JWKS URL is local, we only fetch it once. For remote JWKS URLs we fetch it based + on the refresh interval if refreshing has been enabled with a minimum interval between + attempts. The fetcher functions set the fetched_at timestamp to the current monotonic time + when the JWKS is fetched. + """ + if not self.url.startswith("http"): + # Fetch local JWKS only if not already loaded + # This could be improved in future by looking at mtime of file. + return not self._jwks + # For remote fetches we check if the JWKS is not loaded (fetched_at = 0) or if the last fetch was more than + # refresh_interval_secs ago and the last fetch attempt was more than refresh_retry_interval_secs ago + now = time.monotonic() + return self.refresh_jwks and ( + not self._jwks + or ( + self.fetched_at == 0 + or ( + now - self.fetched_at > self.refresh_interval_secs + and now - self.last_fetch_attempt_at > self.refresh_retry_interval_secs + ) + ) + ) + + async def get_key(self, kid: str) -> jwt.PyJWK: + """Fetch the JWKS and find the matching key for the token.""" + await self.fetch_jwks() + + if self._jwks: + return self._jwks[kid] + + # It didn't load! 
+ raise KeyError(f"Key ID {kid} not found in keyset") + + +def _conf_factory(section, key, **kwargs): + def factory() -> str: + from airflow.configuration import conf + + return conf.get(section, key, **kwargs, suppress_warnings=True) # type: ignore[return-value] + + return factory + + +def _conf_list_factory(section, key, first_only: bool = False, **kwargs): + def factory() -> list[str] | str: + from airflow.configuration import conf + + val = conf.getlist(section, key, **kwargs, suppress_warnings=True) + + if first_only and val: + return val[0] + return val + + return factory + + +def _sec_to_timedelta(what: int | timedelta) -> timedelta: + if isinstance(what, timedelta): + return what + return timedelta(seconds=what) + + +def _to_list(val: str | list[str]) -> list[str]: + if isinstance(val, str): + val = [val] + return val + + [email protected](kw_only=True) +class JWTValidator: + """ + Validate the claims and validitory of a JWT. + + This will either validate the JWT is signed with the symmetric key if ``secret_key`` is passed, or else + that it is signed by one of the public keys in the keyset in ``jwks`` attribute. 
+ """ + + jwks: JWKS | None = None + secret_key: str | None = attrs.field(repr=False, default=None, converter=lambda v: None if v == "" else v) + issuer: str | list[str] | None = attrs.field( + factory=_conf_list_factory("api_auth", "jwt_issuer", fallback=None) + ) + # By default, we just validate these + required_claims: frozenset[str] = frozenset({"exp", "iat", "nbf"}) + audience: str | Sequence[str] + algorithm: list[str] = attrs.field(default=["GUESS"], converter=_to_list) + + leeway: timedelta = attrs.field(default=timedelta(seconds=5), converter=_sec_to_timedelta) + + def __attrs_post_init__(self): + if not (self.jwks is None) ^ (self.secret_key is None): + raise ValueError("Exactly one of priavte_key and secret_key must be specified") Review Comment: ```suggestion raise ValueError("Exactly one of private_key and secret_key must be specified") ``` ########## airflow/security/tokens.py: ########## @@ -0,0 +1,504 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +import os +import time +from collections.abc import Sequence +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + +import attrs +import httpx +import jwt +import structlog +from asgiref.sync import async_to_sync +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.serialization import load_pem_private_key + +from airflow.utils import timezone + +if TYPE_CHECKING: + from jwt.algorithms import AllowedKeys, AllowedPrivateKeys + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "InvalidClaimError", + "JWKS", + "JWTGenerator", + "JWTValidator", + "generate_private_key", + "get_signing_key", + "key_to_pem", + "key_to_jwk_dict", +] + + +class InvalidClaimError(ValueError): + """Raised when a claim in the JWT is invalid.""" + + def __init__(self, claim: str): + super().__init__(f"Invalid claim: {claim}") + + +def key_to_jwk_dict(key: AllowedKeys, kid: str | None = None): + """Convert a public or private key into a valid JWKS dict.""" + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey + from jwt.algorithms import OKPAlgorithm, RSAAlgorithm + + if isinstance(key, (RSAPrivateKey, Ed25519PrivateKey)): + key = key.public_key() + + if isinstance(key, RSAPublicKey): + jwk_dict = RSAAlgorithm(RSAAlgorithm.SHA256).to_jwk(key, as_dict=True) + + elif isinstance(key, Ed25519PublicKey): + jwk_dict = OKPAlgorithm().to_jwk(key, as_dict=True) + else: + raise ValueError(f"Unknown key object {type(key)}") + + if not kid: + kid = thumbprint(jwk_dict) + + jwk_dict["kid"] = kid + + return jwk_dict + + +def _guess_best_algorithm(key: AllowedPrivateKeys): + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey + from 
cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + + if isinstance(key, RSAPrivateKey): + return "RS512" + elif isinstance(key, Ed25519PrivateKey): + return "EdDSA" + else: + raise ValueError(f"Unknown key object {type(key)}") + + [email protected](repr=False) +class JWKS: + """A class to fetch and sync a set of JSON Web Keys.""" + + url: str + fetched_at: float = 0 + last_fetch_attempt_at: float = 0 + + client: httpx.AsyncClient = attrs.field(factory=httpx.AsyncClient) + + _jwks: jwt.PyJWKSet | None = None + refresh_jwks: bool = True + refresh_interval_secs: int = 3600 + refresh_retry_interval_secs: int = 10 + + def __repr__(self) -> str: + return f"JWKS(url={self.url}, fetched_at={self.fetched_at})" + + @classmethod + def from_private_key(cls, **keys: AllowedPrivateKeys): + obj = cls(url=os.devnull) + + obj._jwks = jwt.PyJWKSet([key_to_jwk_dict(key, kid) for kid, key in keys.items()]) + return obj + + async def fetch_jwks(self) -> None: + if not self._should_fetch_jwks(): + return + if self.url.startswith("http"): + data = await self._fetch_remote_jwks() + else: + data = self._fetch_local_jwks() + + if not data: + return + + self._jwks = jwt.PyJWKSet.from_dict(data) + log.debug("Fetched JWKS", url=self.url, keys=len(self._jwks.keys)) + + async def _fetch_remote_jwks(self) -> dict[str, Any] | None: + try: + log.debug( + "Fetching JWKS", + url=self.url, + last_fetched_secs_ago=int(time.monotonic() - self.fetched_at) if self.fetched_at else None, + ) + if TYPE_CHECKING: + assert self.url + self.last_fetch_attempt_at = int(time.monotonic()) + response = await self.client.get(self.url) + response.raise_for_status() + self.fetched_at = int(time.monotonic()) + await response.aread() + await response.aclose() + return response.json() + except Exception: + log.exception("Failed to fetch remote JWKS", url=self.url) + return None + + def _fetch_local_jwks(self) -> dict[str, Any] | None: + try: + with open(self.url) as jwks_file: + content = 
json.load(jwks_file) + self.fetched_at = int(time.monotonic()) + return content + except Exception: + log.exception("Failed to read local JWKS", url=self.url) + return None + + def _should_fetch_jwks(self) -> bool: + """ + Check if we need to fetch the JWKS based on the last fetch time and the refresh interval. + + If the JWKS URL is local, we only fetch it once. For remote JWKS URLs we fetch it based + on the refresh interval if refreshing has been enabled with a minimum interval between + attempts. The fetcher functions set the fetched_at timestamp to the current monotonic time + when the JWKS is fetched. + """ + if not self.url.startswith("http"): + # Fetch local JWKS only if not already loaded + # This could be improved in future by looking at mtime of file. + return not self._jwks + # For remote fetches we check if the JWKS is not loaded (fetched_at = 0) or if the last fetch was more than + # refresh_interval_secs ago and the last fetch attempt was more than refresh_retry_interval_secs ago + now = time.monotonic() + return self.refresh_jwks and ( + not self._jwks + or ( + self.fetched_at == 0 + or ( + now - self.fetched_at > self.refresh_interval_secs + and now - self.last_fetch_attempt_at > self.refresh_retry_interval_secs + ) + ) + ) + + async def get_key(self, kid: str) -> jwt.PyJWK: + """Fetch the JWKS and find the matching key for the token.""" + await self.fetch_jwks() + + if self._jwks: + return self._jwks[kid] + + # It didn't load! 
+ raise KeyError(f"Key ID {kid} not found in keyset") + + +def _conf_factory(section, key, **kwargs): + def factory() -> str: + from airflow.configuration import conf + + return conf.get(section, key, **kwargs, suppress_warnings=True) # type: ignore[return-value] + + return factory + + +def _conf_list_factory(section, key, first_only: bool = False, **kwargs): + def factory() -> list[str] | str: + from airflow.configuration import conf + + val = conf.getlist(section, key, **kwargs, suppress_warnings=True) + + if first_only and val: + return val[0] + return val + + return factory + + +def _sec_to_timedelta(what: int | timedelta) -> timedelta: + if isinstance(what, timedelta): + return what + return timedelta(seconds=what) + + +def _to_list(val: str | list[str]) -> list[str]: + if isinstance(val, str): + val = [val] + return val + + [email protected](kw_only=True) +class JWTValidator: + """ + Validate the claims and validitory of a JWT. + + This will either validate the JWT is signed with the symmetric key if ``secret_key`` is passed, or else + that it is signed by one of the public keys in the keyset in ``jwks`` attribute. 
+ """ + + jwks: JWKS | None = None + secret_key: str | None = attrs.field(repr=False, default=None, converter=lambda v: None if v == "" else v) + issuer: str | list[str] | None = attrs.field( + factory=_conf_list_factory("api_auth", "jwt_issuer", fallback=None) + ) + # By default, we just validate these + required_claims: frozenset[str] = frozenset({"exp", "iat", "nbf"}) + audience: str | Sequence[str] + algorithm: list[str] = attrs.field(default=["GUESS"], converter=_to_list) + + leeway: timedelta = attrs.field(default=timedelta(seconds=5), converter=_sec_to_timedelta) + + def __attrs_post_init__(self): + if not (self.jwks is None) ^ (self.secret_key is None): + raise ValueError("Exactly one of priavte_key and secret_key must be specified") + + if self.algorithm == ["GUESS"]: + if self.jwks: + # TODO: We could probably populate this from the jwks document? + raise ValueError("Cannot guess the algorithm when using JWKS") + else: + self.algorithm = ["HS512"] + + def _get_kid_from_header(self, unvalidated: str) -> str: + header = jwt.get_unverified_header(unvalidated) + if "kid" not in header: + raise ValueError("Missing 'kid' in token header") + return header["kid"] + + async def _get_validation_key(self, unvalidated: str) -> str | jwt.PyJWK: + if self.secret_key: + return self.secret_key + + if TYPE_CHECKING: + assert self.jwks is not None + + kid = self._get_kid_from_header(unvalidated) + return await self.jwks.get_key(kid) + + def validated_claims( + self, unvalidated: str, extra_claims: dict[str, Any] | None = None + ) -> dict[str, Any]: + return async_to_sync(self.avalidated_claims)(unvalidated, extra_claims) + + async def avalidated_claims( + self, unvalidated: str, extra_claims: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Decode the JWT token, returning the validated claims or raising an exception.""" + key = await self._get_validation_key(unvalidated) + claims = jwt.decode( + unvalidated, + key, + audience=self.audience, + issuer=self.issuer, + 
options={"require": self.required_claims}, + algorithms=self.algorithm, + leeway=self.leeway, + ) + + # Validate additional claims if provided + if extra_claims: + for claim, expected_value in extra_claims.items(): + if expected_value["essential"] and ( + claim not in claims or claims[claim] != expected_value["value"] + ): + raise InvalidClaimError(claim) + + return claims + + +def _pem_to_key(pem_data: str | bytes | AllowedPrivateKeys) -> AllowedPrivateKeys: + if isinstance(pem_data, str): + pem_data = pem_data.encode() + elif not isinstance(pem_data, bytes): + # Assume it's already a key object + return pem_data + + return load_pem_private_key(pem_data, password=None) # type: ignore[return-value] + + [email protected](repr=False, kw_only=True) +class JWTGenerator: + """Generate JWT tokens.""" + + _private_key: AllowedPrivateKeys | None = attrs.field( + repr=False, alias="private_key", converter=_pem_to_key + ) + """ + Private key to sign generated tokens. + + Should be either a private key object from the cryptography module, or a PEM-encoded byte string + """ + _secret_key: str | None = attrs.field( + repr=False, + alias="secret_key", + default=None, + converter=lambda v: None if v == "" else v, + ) + """A pre-shared secret key to sign tokens with symmetric encryption""" + + kid: str = attrs.field() + valid_for: timedelta = attrs.field(converter=_sec_to_timedelta) + audience: str + issuer: str | list[str] | None = attrs.field( + factory=_conf_list_factory("api_auth", "jwt_issuer", first_only=True, fallback=None) + ) + algorithm: str = attrs.field(factory=_conf_factory("api_auth", "jwt_algorithm", fallback="GUESS")) + + @_private_key.default + def _load_key_from_configured_file(self) -> bytes | None: + from airflow.configuration import conf + + path = conf.get("api_auth", "jwt_private_key_path", fallback=None) + if not path: + return None + + with open(path, mode="rb") as fh: + return fh.read() + + @kid.default + def _generate_kid(self): + if not 
self._private_key: + return "not-used" + + info = key_to_jwk_dict(self._private_key) + return info["kid"] + + def __attrs_post_init__(self): + if not (self._private_key is None) ^ (self._secret_key is None): + raise ValueError("Exactly one of priavte_key and secret_key must be specified") + + if self.algorithm == "GUESS": + if self._private_key: + ... + else: + self.algorithm = "HS512" + + @property + def signing_arg(self) -> AllowedPrivateKeys | str: + if callable(self._private_key): + return self._private_key() + if self._private_key: + return self._private_key + if TYPE_CHECKING: + # Already handled at in post_init + assert self._secret_key + return self._secret_key + + def generate( + self, subject: str, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None + ) -> str: + """Generate a signed JWT for the subject.""" + now = datetime.now(tz=timezone.utc) + claims = { + "iss": self.issuer, + "aud": self.audience, + "sub": subject, + "nbf": int(now.timestamp()), + "exp": int((now + self.valid_for).timestamp()), + "iat": int(now.timestamp()), + } + if extras is not None: + claims.update(extras) + headers = {"alg": self.algorithm, **(headers or {})} + if self._private_key: + headers["kid"] = self.kid + return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers) + + +# @attrs.define(repr=False) +# class TaskJWTGenerator(JWTGenerator): +# issuer: str = attrs.field(factory=_default_issuer) +# audience: str = attrs.field( +# factory=_conf_factory("task_execution_api", "jwt_audience", fallback="urn:airflow.apache.org:task") +# ) +# algorithm: str = attrs.field( +# factory=_conf_factory("task_execution_api", "jwt_algorithm", default="EdDSA") +# ) + + +def generate_private_key(key_type: str = "RSA", key_size: int = 2048): + """ + Generate a valid private key for testing. + + Args: + key_type (str): Type of key to generate. Can be "RSA" or "Ed25516". Defaults to "RSA". + key_size (int): Size of the key in bits. 
Only applicable for RSA keys. Defaults to 2048. + + Returns: + tuple: A tuple containing the private key in PEM format and the corresponding public key in PEM format. + """ Review Comment: Does this always return a tuple? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
