GitHub user seniuts-b2 added a comment to the discussion: Airflow API Returns 
403 Forbidden When Using Azure AD Authentication via Custom API Backend

Here is a summary of the steps I took to get this setup working.
It may be useful to someone else, and feel free to share a better
solution if you have one.

Env variables:
```
import os

AAD_TENANT_ID = os.getenv("AAD_TENANT_ID")
AAD_CLIENT_ID = os.getenv("AAD_CLIENT_ID")
AAD_CLIENT_SECRET = os.getenv("AAD_CLIENT_SECRET")
# Airflow API Configuration
AIRFLOW_BASE_URL = "https://{{ your Airflow URL }}/auth/fab/v1"
# Azure AD OAuth2 Token URL
AUTH_URL = f"https://login.microsoftonline.com/{AAD_TENANT_ID}/oauth2/v2.0/token"
# Required permission scope for Airflow API
SCOPE = f"api://{AAD_CLIENT_ID}/.default"
# Azure AD OAuth2 Token URL (same endpoint as AUTH_URL)
TOKEN_URL = f"https://login.microsoftonline.com/{AAD_TENANT_ID}/oauth2/v2.0/token"
```

Here is my function to get an Azure AD token:
```
import logging
import time

import requests
from oauthlib.oauth2 import BackendApplicationClient
from requests.exceptions import HTTPError, SSLError
from requests_oauthlib import OAuth2Session

logger = logging.getLogger(__name__)

# Module-level cache, so the token survives between calls
token_cache = {"CACHED_TOKEN": None, "TOKEN_EXPIRES_AT": 0}


# Function: Authenticate with Azure AD
# -----------------------------------------
def get_azure_ad_token_cached():
    """
    Authenticate with Azure AD using the client credentials flow and get an access token.

    Returns a valid Azure AD token. If a cached token exists and is still valid,
    returns the cached one. Otherwise, requests a new token from AAD.
    """
    # If there's a valid token in cache (not expiring within 60 seconds), return it
    if token_cache["CACHED_TOKEN"] and time.time() < token_cache["TOKEN_EXPIRES_AT"] - 60:
        return token_cache["CACHED_TOKEN"]

    # Otherwise, fetch a new token
    logger.info("Fetching a new token from Azure AD.")
    client = BackendApplicationClient(client_id=AAD_CLIENT_ID)
    oauth2_session = OAuth2Session(client=client)

    # payload = {
    #     "grant_type": "client_credentials",
    #     "client_id": AAD_CLIENT_ID,
    #     "client_secret": AAD_CLIENT_SECRET,
    #     "scope": SCOPE
    # }
    # logger.info(f"!!! JUST FOR TEST: Token request payload: {payload}")
    # try:
    #     response = requests.post(AUTH_URL, data=payload, timeout=10)
    #     logger.info(f"!!! JUST FOR TEST !!! Azure AD response:\n{response.text}")
    #     response.raise_for_status()
    #     token_response = response.json()
    #     return token_response["access_token"]
    # except HTTPError as http_err:
    #     logger.error(f"HTTP error while fetching token: {http_err}")
    #     raise http_err
    # except Exception as err:
    #     logger.error(f"Error obtaining Azure AD token: {err}")
    #     raise err

    # Fetch Access Token
    try:
        token_response = oauth2_session.fetch_token(
            token_url=TOKEN_URL,
            client_id=AAD_CLIENT_ID,
            client_secret=AAD_CLIENT_SECRET,
            scope=SCOPE,
        )
        logger.info("Token Acquired Successfully")
        # logger.info(f"!!! JUST FOR TEST !!! Azure AD token:\n{token_response}")
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to fetch token: {e}")
        raise

    expires_in = token_response.get("expires_in", 3599)
    token_cache["TOKEN_EXPIRES_AT"] = time.time() + expires_in

    # Construct the actual bearer token string
    token_cache["CACHED_TOKEN"] = f"Bearer {token_response.get('access_token', '')}"

    return token_cache["CACHED_TOKEN"]
```
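For completeness, a minimal sketch of how the cached token might be used to call the Airflow API. The `call_airflow_api` helper and the `/users` path are assumptions for illustration only, not part of my setup. The HTTP function is passed in, so the snippet can be tried without a network; in real use you would pass `requests.get`.

```python
from typing import Any, Callable


def call_airflow_api(
    base_url: str,
    path: str,
    bearer_token: str,
    http_get: Callable[..., Any],
    timeout: int = 10,
) -> Any:
    """Send a GET request to the Airflow API with an Azure AD bearer token.

    `bearer_token` is expected to already include the "Bearer " prefix,
    as returned by get_azure_ad_token_cached().
    """
    # Join base URL and path without producing a double slash
    url = f"{base_url.rstrip('/')}/{path.lstrip('/')}"
    headers = {"Authorization": bearer_token, "Accept": "application/json"}
    return http_get(url, headers=headers, timeout=timeout)


# Hypothetical usage (requires network access and valid credentials):
# response = call_airflow_api(
#     AIRFLOW_BASE_URL, "/users", get_azure_ad_token_cached(), requests.get
# )
# response.raise_for_status()
```

Passing the HTTP function as an argument is only a convenience for testing; calling `requests.get` directly works just as well.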

For background on Azure AD tokens, see the
[overview](https://learn.microsoft.com/en-us/azure/active-directory-b2c/tokens-overview).
Here is the full code of the API auth backend:
```
"""
Provides an authentication backend for Airflow using Azure Active Directory (Azure AD).

It includes functionality for validating Azure AD JWT tokens, fetching JWKS keys,
and enforcing authentication on API endpoints.
"""

from __future__ import annotations

import json
import logging
import os
from functools import wraps
from typing import Callable, TypeVar, cast

import jwt
import requests
from cryptography.hazmat.primitives import serialization
from flask import Response, current_app, request
from jwt.algorithms import RSAAlgorithm

T = TypeVar("T", bound=Callable)  # TypeVar for function decorators

# ------------------------------
# Logging Configuration
# ------------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ------------------------------
# Azure AD Configuration
# ------------------------------
# !!! IMPORTANT: AAD_TENANT_ID and AAD_CLIENT_ID should be passed as environment
# variables through the webserver section of the values.yaml file
AAD_TENANT_ID = os.getenv("AAD_TENANT_ID")
AAD_CLIENT_ID = os.getenv("AAD_CLIENT_ID")

# JSON Web Key Set (JWKS) URL
JWKS_URL = f"https://login.microsoftonline.com/{AAD_TENANT_ID}/discovery/v2.0/keys"
# Issuer (iss) claim in the token
ISSUER = f"https://sts.windows.net/{AAD_TENANT_ID}/"
# Expected audience (should match the Application ID URI)
AUDIENCE = f"api://{AAD_CLIENT_ID}"
# Allow 60 seconds of clock drift for `exp`, `nbf`, `iat` validation
LEEWAY_SECONDS = 60

if not AAD_TENANT_ID or not AAD_CLIENT_ID:
    msg = (
        "Missing required environment variables (should be passed as env variables "
        "through the webserver section of the values.yaml file): AAD_TENANT_ID, AAD_CLIENT_ID"
    )
    logger.error(msg)
    raise ValueError(msg)

logger.info(f"Fetching JWKS from: {JWKS_URL}")

# Cache JWKS keys to avoid frequent network requests
jwks_cache: dict[str, dict] = {}


def get_azure_jwks():
    """
    Fetch and cache Azure AD JWKS (JSON Web Key Set) to verify token signatures.

    JWKS is a public key repository published by Azure AD.
    It contains multiple keys that are used to verify JWT signatures.

    Azure AD JWKS URL looks like: 
https://login.microsoftonline.com/{AAD_TENANT_ID}/discovery/v2.0/keys

    Azure AD rotates signing keys periodically. This cache is filled once per process,
    so after a key rollover the process has to be restarted to pick up the new keys:
    
https://learn.microsoft.com/en-us/entra/identity-platform/signing-key-rollover#:~:text=metadata%20document.-,For%20security%20purposes%2C%20the%20Microsoft%20identity%20platform%E2%80%99s%20signing%20key%20rolls%20on,a%20key%20rollover%20event%20no%20matter%20how%20frequently%20it%20may%20occur.,-If%20your%20application
    """
    global jwks_cache  # noqa: PLW0603
    if not jwks_cache:
        logger.info("Fetching JWKS from Azure AD...")
        response = requests.get(JWKS_URL, timeout=10)
        response.raise_for_status()
        jwks_cache = response.json()
        logger.info("JWKS fetched and cached.")
    return jwks_cache


def validate_azure_token(token: str) -> dict | bool:
    """
    Validate the Azure AD JWT access token received with an Airflow API request.

    Returns the decoded token claims if the token is valid, otherwise False.

    JWT (JSON Web Token) has three parts:
    # Header (contains metadata like algorithm and key ID kid)
    # Payload (contains claims like iss, exp, aud, etc.)
    # Signature (ensures integrity)

    Validate an Azure AD JWT access token workflow:
    1. Extract the JWT Header → Get the "kid"
    2. Fetch JWKS from Azure → Get the list of public keys
    3. Find the matching kid → Use it to verify the JWT
    4. Decode JWT using the correct key → Validate signature and claims
    5. If everything is valid, grant access

    Short workflow: The function fetches the JWKS, finds the correct key, and 
verifies the token signature.
    """
    logger.info("Validating Azure AD access token...")

    if len(token.split(".")) != 3:  # noqa: PLR2004
        logger.error("Invalid token format. Access denied.")
        return False

    # Decode the JWT header to get the key ID (kid):
    ## The JWT header contains a "kid" (Key ID) field, which tells which public key was used to sign the JWT.
    try:
        header = jwt.get_unverified_header(token)
    except jwt.InvalidTokenError as e:
        logger.error(f"Malformed token header. Access denied: {e}")
        return False
    # Fetch JWKS and find the correct key.
    jwks = get_azure_jwks()
    # logger.info(f"JWKS: {jwks}")
    key = next((json.dumps(k) for k in jwks["keys"] if k["kid"] == header.get("kid")), None)
    if not key:
        logger.error("Token key ID (kid) not found in JWKS. Token is invalid.")
        return False
    logger.info("Matching JWKS key found. Verifying token signature...")

    # Decode and verify the token:
    public_key = RSAAlgorithm.from_jwk(key)
    public_key_bytes = public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    try:
        ## Verify the expected JWT claims and signature and return the decoded token claims:
        ## https://pyjwt.readthedocs.io/en/stable/api.html#jwt.decode
        decoded_token = jwt.decode(
            jwt=token,
            key=public_key_bytes,
            algorithms=["RS256"],
            # extended decoding and validation options
            audience=AUDIENCE,  # Ensures token is meant for this API
            issuer=ISSUER,  # Ensures token was issued by Azure AD
            # leeway=LEEWAY_SECONDS,  # Allow 60 seconds of clock drift for exp/nbf validation
            options={
                "verify_signature": True,  # Ensure signature is verified
                "require": ["exp", "iat", "nbf"],  # Ensure these claims exist
                "verify_exp": True,  # Ensure token is not expired
                "verify_iat": True,  # Ensure issued-at (iat) is valid
                "verify_nbf": True,  # Ensure not-before (nbf) is valid
                "verify_aud": True,  # Ensure audience (aud) claim is valid
                "verify_iss": True,  # Ensure issuer (iss) claim is valid
                # If True, aud must be a single value (not a list) that matches audience exactly
                "strict_aud": False,
            },
        )
        logger.info(f"Decoded token: {decoded_token}")
        logger.info("Token is valid. Access granted.")
        return decoded_token
    except jwt.ExpiredSignatureError as e:
        logger.error(f"Token has expired. Access denied: {e}")
    except jwt.InvalidAudienceError as e:
        logger.error(f"Invalid audience. Expected: {AUDIENCE}: {e}")
    except jwt.InvalidIssuerError as e:
        logger.error(f"Invalid issuer. Expected: {ISSUER}: {e}")
    except jwt.InvalidTokenError as e:
        logger.error(f"Invalid token error: {e}")
    return False


def lookup_airflow_user(token_claims: dict):
    """Finds an Airflow user or App ID based on token claims."""
    security_manager = current_app.appbuilder.sm  # Access the security manager

    # Get user from token (try email first, fallback to appid)
    user_email = token_claims.get("email") or token_claims.get("preferred_username")
    user_appid = token_claims.get("appid")  # App ID for service accounts
    logger.info(f"Looking for user with email: '{user_email}' or appid: '{user_appid}'")

    user = None
    if user_email:
        user = security_manager.find_user(email=user_email)
    if not user and user_appid:
        logger.info(f"Looking for service account user with appid: {user_appid}")
        user = security_manager.find_user(username=user_appid)  # Match appid to a user
    if not user:
        logger.error(f"User {user_email or user_appid} not found in Airflow. Access denied.")
        return None
    if not user.is_active:
        logger.error(f"User {user.username} is inactive. Access denied.")
        return None
    return user


def set_current_airflow_user(user):
    """Sets the current Airflow user in the request context."""
    current_app.appbuilder.sm.lm._update_request_context_with_user(user=user)


def requires_authentication(function: T):
    """Decorator to enforce authentication on API endpoints using Azure AD token."""

    @wraps(function)
    def decorated(*args, **kwargs):
        logger.info(" ### Custom Airflow API Backend to validate Azure AD Authentication ###")
        logger.info(f"Executing function: {function.__name__}")  # Extract function name
        logger.info(f"request: {request}")

        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            logger.error("Missing or malformed Authorization header. Access denied.")
            return Response("Unauthorized", 401, {"WWW-Authenticate": "Bearer"})

        # The token is part of the Authorization header, which has the format "Bearer <token>",
        # so it can be extracted by splitting the header
        token = auth_header.split(" ")[1]
        token_claims = validate_azure_token(token)
        if not token_claims:
            logger.error("Invalid or expired token. Access denied.")
            return Response("Unauthorized", 401, {"WWW-Authenticate": "Bearer"})

        user = lookup_airflow_user(token_claims)
        if not user:
            return Response("Forbidden", 403)

        set_current_airflow_user(user)

        logger.info("Authentication successful. Processing request.")
        return function(*args, **kwargs)

    return cast(T, decorated)


def init_app(_):
    """Initialize the authentication backend required by Airflow."""

```
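When the backend rejects a token, it helps to look at the claims before changing any configuration. The sketch below decodes the payload without verifying the signature, using only the standard library. It is for debugging only and must never be used to make an access decision. The helper name `peek_jwt_claims` is hypothetical. Comparing `iss` and `aud` from its output with `ISSUER` and `AUDIENCE` above usually shows why validation fails. As far as I know, v1.0 tokens use the `https://sts.windows.net/{tenant}/` issuer, while v2.0 tokens use `https://login.microsoftonline.com/{tenant}/v2.0`.

```python
import base64
import json


def peek_jwt_claims(token: str) -> dict:
    """Return the JWT payload as a dict WITHOUT verifying the signature.

    Debugging aid only: the result is untrusted input.
    """
    # Accept both a raw token and a full "Bearer <token>" header value
    if token.startswith("Bearer "):
        token = token[len("Bearer "):]
    parts = token.strip().split(".")
    if len(parts) != 3:
        raise ValueError("Not a JWT: expected three dot-separated parts")
    payload = parts[1]
    # base64url segments in a JWT omit padding; restore it before decoding
    payload += "=" * (-len(payload) % 4)
    return json.loads(base64.urlsafe_b64decode(payload))
```

For example, `peek_jwt_claims(get_azure_ad_token_cached())` on the client side shows which `aud`, `iss` and `appid` values the backend will see.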

Here are the needed updates in _Azure App registrations_ :
- Add Application ID URI (Expose an API section):
![image](https://github.com/user-attachments/assets/d101840a-6cbc-4acc-b326-4a79ebd0d49c)
- Add API permissions:
![image](https://github.com/user-attachments/assets/fe86a44c-6dfc-48b6-914d-ef5d6aa47edb)

Permissions have to be requested for your own API.
That is, you request/add the _Admin_ permission through _App registrations_ for
the same _App registration_:

![image](https://github.com/user-attachments/assets/46658132-ed39-4172-8893-938b0b61cd21)

![image](https://github.com/user-attachments/assets/157148dc-c71f-4cdc-8905-e268cc363184)

![image](https://github.com/user-attachments/assets/e4dadd39-fa11-49ed-9a4e-a738e38f46a0)

_Admin_ is a role added in _App roles_ :
![image](https://github.com/user-attachments/assets/bad7a3d5-ba15-4eee-80e2-b40088d74b00)

Permissions have to be granted by the Admin (Admin consent required).
Adding the _Admin_ permission to your API (__App registrations__) passes the
corresponding claim in the AD token.
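
To confirm that the role actually reaches the token, the validated claims can be checked for it. A small sketch, assuming the app role shows up in the `roles` claim (the usual place for app roles in Azure AD tokens) and that its value is `Admin`. `has_app_role` is a hypothetical helper, not part of the backend above.

```python
def has_app_role(token_claims: dict, role: str = "Admin") -> bool:
    """Return True if the decoded token claims contain the given app role."""
    roles = token_claims.get("roles") or []
    # Be tolerant of a single string value, although the claim is normally a list
    if isinstance(roles, str):
        roles = [roles]
    return role in roles
```

If the claim is missing, it is worth rechecking that admin consent was granted before looking anywhere else.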

-  **Important Note:** Changes made in **App registrations (AppR)** and 
assigning the **Admin** role to the AppR do **not** automatically grant access 
in Airflow. In other words, Airflow doesn’t recognize the AppR as a user by 
default.  
    To fix this, we need to explicitly **add the AppR's `clientID` as an Admin 
user in Airflow**. You can do this by following these steps:
    -   Connect to an Airflow pod in your Kubernetes cluster, e.g., the 
webserver:        
        `kubectl exec -it airflow-webserver-584df9db77-h65g5 -n airflow -- 
/bin/bash`
    -   Inside the pod, run the following Airflow CLI command:
        `airflow users create  --username clientID  --password $(python3 -c 
'import secrets; print(secrets.token_hex(16))')  --firstname Service   
--lastname Account  --role Admin  --email [email protected]`
        Here, `--username` should be set to the **client ID** of the AppR.
        You can generate your own secure password as shown above, or simply use
`--use-random-password`.  
        Reference: [Airflow CLI – create 
user](https://airflow.apache.org/docs/apache-airflow-providers-fab/stable/cli-ref.html#create)

Once all of this is done, Airflow recognizes the **AppR clientID** as an
**Admin** user, allowing it to authenticate and access the Airflow APIs as expected.



GitHub link: 
https://github.com/apache/airflow/discussions/47029#discussioncomment-12852898
