Source code for rootski.services.auth
"""
Logic for authenticating incoming requests.
The code heavily borrows from this article:
https://gntrm.medium.com/jwt-authentication-with-fastapi-and-aws-cognito-1333f7f2729e
"""
from typing import List, Optional
import httpx
from jose import JWTError, jwk, jwt
from jose.utils import base64url_decode
from loguru import logger
from pydantic import BaseModel
from rootski.config.config import ANON_USER, Config
from rootski.errors import AuthServiceError
from rootski.services.service import Service
[docs]class JsonWebKey(BaseModel):
    """
    Learn about Cognito JWKs here:
    https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
    """
    kid: str  # key ID
    kty: str
    class Config:
        extra = "allow"
[docs]class AuthService(Service):
    _jwks: Optional[JsonWebKeySet] = None
    @classmethod
    def from_config(cls, config: Config):
        return cls(cognito_public_keys_url=config.cognito_public_keys_url)
    def __init__(self, cognito_public_keys_url: str):
        """Abstraction layer around verifying tokens."""
        self.__cognito_public_keys_url = cognito_public_keys_url
    def init(self):
        logger.info("Fetching Cognito Keys")
        self._jwks = get_jwks(self.__cognito_public_keys_url)
        logger.info(f"Fetched these keys: {str(self._jwks.json())}")
    def token_is_valid(self, token: str) -> bool:
        if not self._jwks:
            raise AuthServiceError("The auth service is not initialized. Did you call .init()?")
        if not token_is_well_formed(token=token):
            return False
        logger.info(f"Validating token: {token}")
        return jwt_is_valid(token, self._jwks)
[docs]    def get_token_email(self, token: str) -> Optional[str]:
        """Retrieve the email from the token, or return the anonymous user."""
        logger.info(f"token {token}")
        if not token_is_well_formed(token):
            error_msg = (
                f"Got this error while getting the 'email' from the JWT token {str(e)}"
                + f"\n\nToken: {str(token)}"
            )
            logger.error(error_msg)
            raise AuthServiceError("Error, JWT token is not wellformed. See logs for details.")
        return jwt.get_unverified_claims(token).get("email", ANON_USER)
[docs]def token_is_well_formed(token: str) -> bool:
    """Return ``True`` if the token can be decoded without verifying the signature."""
    try:
        jwt.get_unverified_claims(token)
        jwt.get_unverified_headers(token)
    except JWTError:
        return False
    return True
[docs]def get_jwks(jwk_url: str) -> JsonWebKeySet:
    response = httpx.get(jwk_url)
    return JsonWebKeySet(**response.json())
[docs]def get_token_jwk(token: str, jwks: JsonWebKeySet) -> Optional[JsonWebKey]:
    """Return the Cognito public key whose ID matches the key ID in the token header.
    If our Cognito user pool does not have a matching key, return ``None``... we're
    not going to be able to authenticate this token. :(
    Args:
        token: JWT token from the header of an incoming request
        jwks: Json Web Keys corresponding to our Cognito user pool
    """
    try:
        token_kid = jwt.get_unverified_header(token).get("kid")
    except JWTError as e:
        logger.error(f"Got this error while getting the email from the JWT token {str(e)}")
        raise AuthServiceError("Error while getting email from JWT claims. See logs for details.")
    for key in jwks.keys:
        if key.kid == token_kid:
            return key
[docs]def jwt_is_valid(token: str, jwks: JsonWebKeySet) -> bool:
    """Return ``True`` if the jwt ``token`` was signed by our Cognito user pool identity server."""
    token_jwk: Optional[JsonWebKey] = get_token_jwk(token, jwks)
    if not token_jwk:
        raise AuthServiceError(
            "No public key found! Did you call AuthService.init()? Are the Cognito config values right?"
        )
    hmac_key = jwk.construct(token_jwk.dict())
    message, encoded_signature = token.rsplit(".", 1)
    decoded_signature = base64url_decode(encoded_signature.encode())
    return hmac_key.verify(message.encode(), decoded_signature)