ashb commented on code in PR #46981:
URL: https://github.com/apache/airflow/pull/46981#discussion_r1969792517


##########
airflow/security/tokens.py:
##########
@@ -0,0 +1,504 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import os
+import time
+from collections.abc import Sequence
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any
+
+import attrs
+import httpx
+import jwt
+import structlog
+from asgiref.sync import async_to_sync
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.serialization import load_pem_private_key
+
+from airflow.utils import timezone
+
+if TYPE_CHECKING:
+    from jwt.algorithms import AllowedKeys, AllowedPrivateKeys
+
# Module-level structlog logger, tagged with this module's name under ``logger_name``.
log = structlog.get_logger(logger_name=__name__)

# Names exported by ``from <this module> import *``.
__all__ = [
    "InvalidClaimError",
    "JWKS",
    "JWTGenerator",
    "JWTValidator",
    "generate_private_key",
    "get_signing_key",
    "key_to_pem",
    "key_to_jwk_dict",
]
+
+
class InvalidClaimError(ValueError):
    """
    Signal that a JWT carried a claim whose value failed validation.

    The name of the offending claim is embedded in the exception message.
    """

    def __init__(self, claim: str):
        message = f"Invalid claim: {claim}"
        super().__init__(message)
+
+
def key_to_jwk_dict(key: AllowedKeys, kid: str | None = None):
    """
    Build a JSON Web Key (JWK) dictionary describing the public half of ``key``.

    :param key: An RSA or Ed25519 key object from ``cryptography``. Private keys are accepted, but only
        their public key is serialized.
    :param kid: Key ID stored under ``"kid"``. When empty or ``None`` it is derived by calling
        ``thumbprint`` on the public JWK.
    :raises ValueError: If ``key`` is neither an RSA nor an Ed25519 key.
    """
    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey
    from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
    from jwt.algorithms import OKPAlgorithm, RSAAlgorithm

    # Never serialize private material: reduce a private key to its public counterpart first.
    public_key = key.public_key() if isinstance(key, (RSAPrivateKey, Ed25519PrivateKey)) else key

    if isinstance(public_key, RSAPublicKey):
        algorithm = RSAAlgorithm(RSAAlgorithm.SHA256)
    elif isinstance(public_key, Ed25519PublicKey):
        algorithm = OKPAlgorithm()
    else:
        raise ValueError(f"Unknown key object {type(public_key)}")

    jwk_dict = algorithm.to_jwk(public_key, as_dict=True)
    # The thumbprint is computed from the JWK before the "kid" member is added to it.
    jwk_dict["kid"] = kid or thumbprint(jwk_dict)
    return jwk_dict
+
+
def _guess_best_algorithm(key: AllowedPrivateKeys):
    """
    Choose the JWT signing algorithm name to use with a private key object.

    RSA private keys map to ``"RS512"`` and Ed25519 private keys map to ``"EdDSA"``.

    :raises ValueError: For any other key type.
    """
    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
    from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey

    # Checked in order; the first matching key type wins.
    candidates = (
        (RSAPrivateKey, "RS512"),
        (Ed25519PrivateKey, "EdDSA"),
    )
    for key_type, algorithm in candidates:
        if isinstance(key, key_type):
            return algorithm
    raise ValueError(f"Unknown key object {type(key)}")
+
+
@attrs.define(repr=False)
class JWKS:
    """
    A class to fetch and sync a set of JSON Web Keys.

    ``url`` is either an ``http(s)`` URL, fetched with ``client`` and periodically re-fetched (see
    ``_should_fetch_jwks``), or a local file path that is read once.
    """

    url: str
    # time.monotonic() value at which a keyset was last successfully loaded; 0 means "never".
    fetched_at: float = 0
    # time.monotonic() value of the last remote request (successful or not); throttles retries.
    last_fetch_attempt_at: float = 0

    client: httpx.AsyncClient = attrs.field(factory=httpx.AsyncClient)

    _jwks: jwt.PyJWKSet | None = None
    # When False a remote keyset is still loaded once, but it is never refreshed afterwards.
    refresh_jwks: bool = True
    refresh_interval_secs: int = 3600
    refresh_retry_interval_secs: int = 10

    def __repr__(self) -> str:
        return f"JWKS(url={self.url}, fetched_at={self.fetched_at})"

    @classmethod
    def from_private_key(cls, **keys: AllowedPrivateKeys):
        """
        Build an in-memory keyset from private keys, publishing only their public halves.

        Each keyword argument name becomes the key ID (``kid``) of its key. ``url`` is set to
        ``os.devnull`` and is never read, because the keyset is already loaded.
        """
        obj = cls(url=os.devnull)

        obj._jwks = jwt.PyJWKSet([key_to_jwk_dict(key, kid) for kid, key in keys.items()])
        return obj

    async def fetch_jwks(self) -> None:
        """
        Load (or refresh) the keyset if ``_should_fetch_jwks`` says it is due.

        Fetch or parse failures are logged and leave any previously loaded keyset in place.
        """
        if not self._should_fetch_jwks():
            return
        if self.url.startswith("http"):
            data = await self._fetch_remote_jwks()
        else:
            data = self._fetch_local_jwks()

        if not data:
            return

        try:
            jwks = jwt.PyJWKSet.from_dict(data)
        except Exception:
            # Keep serving the previous keyset (if any) rather than failing every lookup on a bad document.
            log.exception("Failed to parse JWKS", url=self.url)
            return

        self._jwks = jwks
        # Only a usable keyset counts as "fetched": a bad response is then retried after
        # refresh_retry_interval_secs instead of only after a full refresh_interval_secs.
        self.fetched_at = time.monotonic()
        log.debug("Fetched JWKS", url=self.url, keys=len(jwks.keys))

    async def _fetch_remote_jwks(self) -> dict[str, Any] | None:
        """Download and decode the JWKS document over HTTP(S); log and return ``None`` on any failure."""
        try:
            log.debug(
                "Fetching JWKS",
                url=self.url,
                last_fetched_secs_ago=int(time.monotonic() - self.fetched_at) if self.fetched_at else None,
            )
            # Stamped before the request so that failed attempts are throttled too.
            self.last_fetch_attempt_at = time.monotonic()
            response = await self.client.get(self.url)
            response.raise_for_status()
            await response.aread()
            await response.aclose()
            return response.json()
        except Exception:
            log.exception("Failed to fetch remote JWKS", url=self.url)
            return None

    def _fetch_local_jwks(self) -> dict[str, Any] | None:
        """Read and decode the JWKS document from the local file ``url``; log and return ``None`` on failure."""
        try:
            with open(self.url) as jwks_file:
                return json.load(jwks_file)
        except Exception:
            log.exception("Failed to read local JWKS", url=self.url)
            return None

    def _should_fetch_jwks(self) -> bool:
        """
        Decide whether ``fetch_jwks`` should (re)load the keyset now.

        If no keyset is loaded yet, always fetch, for both local and remote URLs and regardless of
        ``refresh_jwks``; that flag only controls *re*-fetching. A loaded local keyset is never re-read.
        A loaded remote keyset is refreshed only when refreshing is enabled, ``refresh_interval_secs``
        have passed since the last successful load, and ``refresh_retry_interval_secs`` have passed since
        the last attempt.
        """
        if self._jwks is None:
            return True
        if not self.url.startswith("http"):
            # This could be improved in future by looking at mtime of file.
            return False
        if not self.refresh_jwks:
            return False
        now = time.monotonic()
        return self.fetched_at == 0 or (
            now - self.fetched_at > self.refresh_interval_secs
            and now - self.last_fetch_attempt_at > self.refresh_retry_interval_secs
        )

    async def get_key(self, kid: str) -> jwt.PyJWK:
        """
        Return the key whose ID is ``kid``, (re)loading the keyset first if that is due.

        :raises KeyError: If no keyset could be loaded. PyJWT's ``PyJWKSet`` lookup also raises
            ``KeyError`` when the loaded keyset has no key with this ID.
        """
        await self.fetch_jwks()

        if self._jwks:
            return self._jwks[kid]

        # It didn't load!
        raise KeyError(f"Key ID {kid} not found in keyset")
+
+
+def _conf_factory(section, key, **kwargs):
+    def factory() -> str:
+        from airflow.configuration import conf
+
+        return conf.get(section, key, **kwargs, suppress_warnings=True)  # 
type: ignore[return-value]
+
+    return factory
+
+
+def _conf_list_factory(section, key, first_only: bool = False, **kwargs):
+    def factory() -> list[str] | str:
+        from airflow.configuration import conf
+
+        val = conf.getlist(section, key, **kwargs, suppress_warnings=True)
+
+        if first_only and val:
+            return val[0]
+        return val
+
+    return factory
+
+
+def _sec_to_timedelta(what: int | timedelta) -> timedelta:
+    if isinstance(what, timedelta):
+        return what
+    return timedelta(seconds=what)
+
+
+def _to_list(val: str | list[str]) -> list[str]:
+    if isinstance(val, str):
+        val = [val]
+    return val
+
+
@attrs.define(kw_only=True)
class JWTValidator:
    """
    Validate the claims and validity of a JWT.

    This will either validate the JWT is signed with the symmetric key if ``secret_key`` is passed, or else
    that it is signed by one of the public keys in the keyset in the ``jwks`` attribute. Exactly one of
    ``jwks`` and ``secret_key`` must be given.
    """

    jwks: JWKS | None = None
    # An empty string is normalised to None so that it counts as "not provided".
    secret_key: str | None = attrs.field(repr=False, default=None, converter=lambda v: None if v == "" else v)
    issuer: str | list[str] | None = attrs.field(
        factory=_conf_list_factory("api_auth", "jwt_issuer", fallback=None)
    )
    # By default, we just validate these
    required_claims: frozenset[str] = frozenset({"exp", "iat", "nbf"})
    audience: str | Sequence[str]
    # "GUESS" is resolved in __attrs_post_init__. A factory (rather than a list literal) guarantees that no
    # list object is shared between instances.
    algorithm: list[str] = attrs.field(factory=lambda: ["GUESS"], converter=_to_list)

    # Leeway passed to jwt.decode; plain numbers are interpreted as seconds.
    leeway: timedelta = attrs.field(default=timedelta(seconds=5), converter=_sec_to_timedelta)

    def __attrs_post_init__(self):
        # Exactly one verification source: a keyset (asymmetric) or a shared secret (symmetric).
        if (self.jwks is None) == (self.secret_key is None):
            raise ValueError("Exactly one of jwks and secret_key must be specified")

        if self.algorithm == ["GUESS"]:
            if self.jwks is not None:
                # TODO: We could probably populate this from the jwks document?
                raise ValueError("Cannot guess the algorithm when using JWKS")
            self.algorithm = ["HS512"]

    def _get_kid_from_header(self, unvalidated: str) -> str:
        """Return the ``kid`` from the token's (unverified) header, or raise ``ValueError`` if absent."""
        header = jwt.get_unverified_header(unvalidated)
        if "kid" not in header:
            raise ValueError("Missing 'kid' in token header")
        return header["kid"]

    async def _get_validation_key(self, unvalidated: str) -> str | jwt.PyJWK:
        """Return the shared secret, or else the JWKS key named by the token's ``kid`` header."""
        if self.secret_key:
            return self.secret_key

        if TYPE_CHECKING:
            assert self.jwks is not None

        kid = self._get_kid_from_header(unvalidated)
        return await self.jwks.get_key(kid)

    def validated_claims(
        self, unvalidated: str, extra_claims: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Synchronous wrapper around :meth:`avalidated_claims`."""
        return async_to_sync(self.avalidated_claims)(unvalidated, extra_claims)

    async def avalidated_claims(
        self, unvalidated: str, extra_claims: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """
        Decode the JWT token, returning the validated claims or raising an exception.

        :param unvalidated: The encoded JWT.
        :param extra_claims: Optional mapping of claim name to a spec dict with ``"essential"`` and
            ``"value"`` keys. For every spec whose ``"essential"`` is truthy, the token must carry that
            claim with exactly that value; non-essential specs are not checked.
        :raises InvalidClaimError: If an essential extra claim is missing or has a different value.

        Signature, audience, issuer, algorithm and required-claim failures are raised by ``jwt.decode``.
        """
        key = await self._get_validation_key(unvalidated)
        claims = jwt.decode(
            unvalidated,
            key,
            audience=self.audience,
            issuer=self.issuer,
            options={"require": self.required_claims},
            algorithms=self.algorithm,
            leeway=self.leeway,
        )

        # Validate additional claims if provided
        if extra_claims:
            for claim, expected_value in extra_claims.items():
                if expected_value["essential"] and (
                    claim not in claims or claims[claim] != expected_value["value"]
                ):
                    raise InvalidClaimError(claim)

        return claims
+
+
+def _pem_to_key(pem_data: str | bytes | AllowedPrivateKeys) -> 
AllowedPrivateKeys:
+    if isinstance(pem_data, str):
+        pem_data = pem_data.encode()
+    elif not isinstance(pem_data, bytes):
+        # Assume it's already a key object
+        return pem_data
+
+    return load_pem_private_key(pem_data, password=None)  # type: 
ignore[return-value]
+
+
[email protected](repr=False, kw_only=True)
+class JWTGenerator:
+    """Generate JWT tokens."""
+
+    _private_key: AllowedPrivateKeys | None = attrs.field(
+        repr=False, alias="private_key", converter=_pem_to_key
+    )
+    """
+    Private key to sign generated tokens.
+
+    Should be either a private key object from the cryptography module, or a 
PEM-encoded byte string
+    """
+    _secret_key: str | None = attrs.field(
+        repr=False,
+        alias="secret_key",
+        default=None,
+        converter=lambda v: None if v == "" else v,
+    )
+    """A pre-shared secret key to sign tokens with symmetric encryption"""
+
+    kid: str = attrs.field()
+    valid_for: timedelta = attrs.field(converter=_sec_to_timedelta)
+    audience: str
+    issuer: str | list[str] | None = attrs.field(
+        factory=_conf_list_factory("api_auth", "jwt_issuer", first_only=True, 
fallback=None)
+    )
+    algorithm: str = attrs.field(factory=_conf_factory("api_auth", 
"jwt_algorithm", fallback="GUESS"))
+
+    @_private_key.default
+    def _load_key_from_configured_file(self) -> bytes | None:
+        from airflow.configuration import conf
+
+        path = conf.get("api_auth", "jwt_private_key_path", fallback=None)
+        if not path:
+            return None
+
+        with open(path, mode="rb") as fh:
+            return fh.read()
+
+    @kid.default
+    def _generate_kid(self):
+        if not self._private_key:
+            return "not-used"
+
+        info = key_to_jwk_dict(self._private_key)
+        return info["kid"]
+
+    def __attrs_post_init__(self):
+        if not (self._private_key is None) ^ (self._secret_key is None):
+            raise ValueError("Exactly one of priavte_key and secret_key must 
be specified")
+
+        if self.algorithm == "GUESS":
+            if self._private_key:
+                ...

Review Comment:
   Because I never finished this block! oops.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to