import hashlib
import json
import logging
import time
import uuid
from typing import Any

import httpx
import jwt
from jwcrypto import jwk
from jwt import PyJWK, PyJWKClient
from starlette.requests import Request
from starlette.responses import JSONResponse

logger = logging.getLogger(__name__)

AUTH_HEADER_PREFIX = 'Bearer '

# Timeout (seconds) for the outbound HTTP calls made by the sender.
_HTTP_TIMEOUT_SECONDS = 10

# The receiver rejects tokens whose ``iat`` is older than this many seconds,
# limiting the window in which a captured notification can be replayed.
_MAX_TOKEN_AGE_SECONDS = 60 * 5


class PushNotificationAuth:
    """Shared helpers for signing and verifying push-notification payloads."""

    def _calculate_request_body_sha256(self, data: dict[str, Any]) -> str:
        """Return the hex SHA-256 digest of ``data`` serialized as compact JSON.

        The signer and the verifier must both produce exactly this serialization.
        The sender hashes the payload it sends, and the receiver hashes the body
        it parsed. The ``json.dumps`` options below are therefore part of the
        protocol and must not change on only one side.

        Raises:
            ValueError: if ``data`` contains NaN/Infinity (``allow_nan=False``).
        """
        body_str = json.dumps(
            data,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        )
        # ensure_ascii=False may emit non-ASCII characters; hash their UTF-8 bytes.
        return hashlib.sha256(body_str.encode("utf-8")).hexdigest()


class PushNotificationSenderAuth(PushNotificationAuth):
    """Signs outgoing push notifications and publishes the matching public keys."""

    def __init__(self):
        # Public halves of every key generated so far, served by handle_jwks_endpoint.
        self.public_keys: list[dict[str, Any]] = []
        # Current signing key; None until generate_jwk() is called.
        self.private_key_jwk: PyJWK | None = None

    @staticmethod
    async def verify_push_notification_url(url: str) -> bool:
        """Check that ``url`` echoes back a freshly generated validation token.

        Sends ``GET url?validationToken=<uuid4>`` and reports success only when
        the response body equals the token exactly. Any failure is logged and
        reported as ``False`` rather than raised, including network errors and
        non-2xx statuses.
        """
        validation_token = str(uuid.uuid4())
        async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT_SECONDS) as client:
            try:
                response = await client.get(
                    url, params={"validationToken": validation_token}
                )
                response.raise_for_status()
            except Exception as e:
                # Best-effort check: any failure simply means "not verified".
                logger.warning(
                    "Error during verifying push-notification URL %s: %s", url, e
                )
                return False
            is_verified = response.text == validation_token
            logger.info("Verified push-notification URL: %s => %s", url, is_verified)
            return is_verified

    def generate_jwk(self) -> None:
        """Generate a new 2048-bit RSA signing key and make it the current one.

        The public half is appended to ``public_keys``; previously generated
        public keys stay in the list. The private half replaces
        ``private_key_jwk``. A random UUID ``kid`` identifies the key.
        """
        key = jwk.JWK.generate(kty='RSA', size=2048, kid=str(uuid.uuid4()), use="sig")
        self.public_keys.append(key.export_public(as_dict=True))
        self.private_key_jwk = PyJWK.from_json(key.export_private())

    def handle_jwks_endpoint(self, _request: Request) -> JSONResponse:
        """Serve all published public keys as a JWKS document ``{"keys": [...]}``."""
        return JSONResponse({"keys": self.public_keys})

    def _generate_jwt(self, data: dict[str, Any]) -> str:
        """Return an RS256 JWT that binds ``data`` to the moment it was signed.

        Claims:
            iat: issue time in whole seconds.
            request_body_sha256: digest of ``data``.

        The receiver recomputes the digest to detect a modified payload. It also
        rejects stale ``iat`` values to narrow the replay window. The ``kid``
        header lets the receiver pick the matching public key from the JWKS.

        Raises:
            RuntimeError: if ``generate_jwk()`` has not been called yet.
        """
        if self.private_key_jwk is None:
            raise RuntimeError("No signing key available; call generate_jwk() first")
        claims = {
            "iat": int(time.time()),
            "request_body_sha256": self._calculate_request_body_sha256(data),
        }
        # Pass the underlying cryptography key: accepted by every PyJWT 2.x,
        # whereas PyJWK instances are only accepted by recent releases.
        return jwt.encode(
            claims,
            key=self.private_key_jwk.key,
            headers={"kid": self.private_key_jwk.key_id},
            algorithm="RS256",
        )

    async def send_push_notification(self, url: str, data: dict[str, Any]) -> None:
        """POST ``data`` as JSON to ``url``, authenticated with a signed bearer JWT.

        Delivery is best-effort: HTTP and network errors are logged, not raised.

        Raises:
            RuntimeError: before any network call, if no signing key exists.
        """
        jwt_token = self._generate_jwt(data)
        headers = {"Authorization": f"{AUTH_HEADER_PREFIX}{jwt_token}"}
        async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT_SECONDS) as client:
            try:
                response = await client.post(url, json=data, headers=headers)
                response.raise_for_status()
            except Exception as e:
                logger.warning(
                    "Error during sending push-notification for URL %s: %s", url, e
                )
                return
            logger.info("Push-notification sent for URL: %s", url)


class PushNotificationReceiverAuth(PushNotificationAuth):
    """Verifies push notifications signed by a ``PushNotificationSenderAuth``."""

    def __init__(self):
        # Not used within this class; kept so existing attribute access still works.
        self.public_keys_jwks = []
        # Resolves signing keys from the sender's JWKS URL; None until load_jwks().
        self.jwks_client: PyJWKClient | None = None

    async def load_jwks(self, jwks_url: str) -> None:
        """Configure the JWKS URL used to look up the sender's public keys by ``kid``."""
        self.jwks_client = PyJWKClient(jwks_url)

    async def verify_push_notification(self, request: Request) -> bool:
        """Authenticate an incoming push notification.

        Returns:
            ``True`` if the bearer token verifies, the body digest matches and
            the token is fresh. ``False`` if the Authorization header is missing
            or is not a bearer token.

        Raises:
            RuntimeError: if ``load_jwks()`` has not been called.
            ValueError: if the body digest does not match the signed digest, or
                the token is older than ``_MAX_TOKEN_AGE_SECONDS``.
            jwt.PyJWTError: propagated from PyJWT on key lookup, signature, or
                required-claim failures.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
            logger.warning("Invalid authorization header")
            return False
        if self.jwks_client is None:
            raise RuntimeError("JWKS client not initialised; call load_jwks() first")

        token = auth_header[len(AUTH_HEADER_PREFIX):]
        # NOTE(review): PyJWKClient fetches the JWKS synchronously when the key is
        # not cached, which blocks the event loop; consider offloading if needed.
        signing_key = self.jwks_client.get_signing_key_from_jwt(token)
        # Use the underlying cryptography key for compatibility across PyJWT 2.x.
        decoded_token = jwt.decode(
            token,
            signing_key.key,
            options={"require": ["iat", "request_body_sha256"]},
            algorithms=["RS256"],
        )

        actual_body_sha256 = self._calculate_request_body_sha256(await request.json())
        if actual_body_sha256 != decoded_token["request_body_sha256"]:
            # The received payload does not match the digest in the signed token.
            raise ValueError("Invalid request body")

        if time.time() - decoded_token["iat"] > _MAX_TOKEN_AGE_SECONDS:
            # Reject stale notifications to limit replay attacks.
            raise ValueError("Token is expired")
        return True