136 lines
4.8 KiB
Python
136 lines
4.8 KiB
Python
from jwcrypto import jwk
|
|
import uuid
|
|
from starlette.responses import JSONResponse
|
|
from starlette.requests import Request
|
|
from typing import Any
|
|
|
|
import jwt
|
|
import time
|
|
import json
|
|
import hashlib
|
|
import httpx
|
|
import logging
|
|
|
|
from jwt import PyJWK, PyJWKClient
|
|
|
|
logger = logging.getLogger(__name__)
|
|
AUTH_HEADER_PREFIX = 'Bearer '
|
|
|
|
class PushNotificationAuth:
|
|
def _calculate_request_body_sha256(self, data: dict[str, Any]):
|
|
"""Calculates the SHA256 hash of a request body.
|
|
|
|
This logic needs to be same for both the agent who signs the payload and the client verifier.
|
|
"""
|
|
body_str = json.dumps(
|
|
data,
|
|
ensure_ascii=False,
|
|
allow_nan=False,
|
|
indent=None,
|
|
separators=(",", ":"),
|
|
)
|
|
return hashlib.sha256(body_str.encode()).hexdigest()
|
|
|
|
class PushNotificationSenderAuth(PushNotificationAuth):
|
|
def __init__(self):
|
|
self.public_keys = []
|
|
self.private_key_jwk: PyJWK = None
|
|
|
|
@staticmethod
|
|
async def verify_push_notification_url(url: str) -> bool:
|
|
async with httpx.AsyncClient(timeout=10) as client:
|
|
try:
|
|
validation_token = str(uuid.uuid4())
|
|
response = await client.get(
|
|
url,
|
|
params={"validationToken": validation_token}
|
|
)
|
|
response.raise_for_status()
|
|
is_verified = response.text == validation_token
|
|
|
|
logger.info(f"Verified push-notification URL: {url} => {is_verified}")
|
|
return is_verified
|
|
except Exception as e:
|
|
logger.warning(f"Error during sending push-notification for URL {url}: {e}")
|
|
|
|
return False
|
|
|
|
def generate_jwk(self):
|
|
key = jwk.JWK.generate(kty='RSA', size=2048, kid=str(uuid.uuid4()), use="sig")
|
|
self.public_keys.append(key.export_public(as_dict=True))
|
|
self.private_key_jwk = PyJWK.from_json(key.export_private())
|
|
|
|
def handle_jwks_endpoint(self, _request: Request):
|
|
"""Allow clients to fetch public keys.
|
|
"""
|
|
return JSONResponse({
|
|
"keys": self.public_keys
|
|
})
|
|
|
|
def _generate_jwt(self, data: dict[str, Any]):
|
|
"""JWT is generated by signing both the request payload SHA digest and time of token generation.
|
|
|
|
Payload is signed with private key and it ensures the integrity of payload for client.
|
|
Including iat prevents from replay attack.
|
|
"""
|
|
|
|
iat = int(time.time())
|
|
|
|
return jwt.encode(
|
|
{"iat": iat, "request_body_sha256": self._calculate_request_body_sha256(data)},
|
|
key=self.private_key_jwk,
|
|
headers={"kid": self.private_key_jwk.key_id},
|
|
algorithm="RS256"
|
|
)
|
|
|
|
async def send_push_notification(self, url: str, data: dict[str, Any]):
|
|
jwt_token = self._generate_jwt(data)
|
|
headers = {'Authorization': f"Bearer {jwt_token}"}
|
|
async with httpx.AsyncClient(timeout=10) as client:
|
|
try:
|
|
response = await client.post(
|
|
url,
|
|
json=data,
|
|
headers=headers
|
|
)
|
|
response.raise_for_status()
|
|
logger.info(f"Push-notification sent for URL: {url}")
|
|
except Exception as e:
|
|
logger.warning(f"Error during sending push-notification for URL {url}: {e}")
|
|
|
|
class PushNotificationReceiverAuth(PushNotificationAuth):
|
|
def __init__(self):
|
|
self.public_keys_jwks = []
|
|
self.jwks_client = None
|
|
|
|
async def load_jwks(self, jwks_url: str):
|
|
self.jwks_client = PyJWKClient(jwks_url)
|
|
|
|
async def verify_push_notification(self, request: Request) -> bool:
|
|
auth_header = request.headers.get("Authorization")
|
|
if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
|
|
print("Invalid authorization header")
|
|
return False
|
|
|
|
token = auth_header[len(AUTH_HEADER_PREFIX):]
|
|
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
|
|
|
|
decode_token = jwt.decode(
|
|
token,
|
|
signing_key,
|
|
options={"require": ["iat", "request_body_sha256"]},
|
|
algorithms=["RS256"],
|
|
)
|
|
|
|
actual_body_sha256 = self._calculate_request_body_sha256(await request.json())
|
|
if actual_body_sha256 != decode_token["request_body_sha256"]:
|
|
# Payload signature does not match the digest in signed token.
|
|
raise ValueError("Invalid request body")
|
|
|
|
if time.time() - decode_token["iat"] > 60 * 5:
|
|
# Do not allow push-notifications older than 5 minutes.
|
|
# This is to prevent replay attack.
|
|
raise ValueError("Token is expired")
|
|
|
|
return True
|