From ebab128b85d572970bf9e7a2beb10167927232b4 Mon Sep 17 00:00:00 2001 From: Atropa-Solanaceae <89823371+Atropa-Solanaceae@users.noreply.github.com> Date: Mon, 7 Oct 2024 17:41:42 -0400 Subject: [PATCH] Implemented JWS verifying using JWK --- jwt/algorithms.py | 139 ++++++++++++++++++++++++++-------------------- 1 file changed, 78 insertions(+), 61 deletions(-) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 9be50b20..4f42cd00 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import hashlib import hmac import json @@ -24,7 +25,7 @@ from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes - from cryptography.hazmat.primitives.asymmetric import padding + from cryptography.hazmat.primitives.asymmetric import ec, padding from cryptography.hazmat.primitives.asymmetric.ec import ( ECDSA, SECP256K1, @@ -114,24 +115,20 @@ def get_default_algorithms() -> dict[str, Algorithm]: } if has_crypto: - default_algorithms.update( - { - "RS256": RSAAlgorithm(RSAAlgorithm.SHA256), - "RS384": RSAAlgorithm(RSAAlgorithm.SHA384), - "RS512": RSAAlgorithm(RSAAlgorithm.SHA512), - "ES256": ECAlgorithm(ECAlgorithm.SHA256), - "ES256K": ECAlgorithm(ECAlgorithm.SHA256), - "ES384": ECAlgorithm(ECAlgorithm.SHA384), - "ES521": ECAlgorithm(ECAlgorithm.SHA512), - "ES512": ECAlgorithm( - ECAlgorithm.SHA512 - ), # Backward compat for #219 fix - "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256), - "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384), - "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512), - "EdDSA": OKPAlgorithm(), - } - ) + default_algorithms |= { + "RS256": RSAAlgorithm(RSAAlgorithm.SHA256), + "RS384": RSAAlgorithm(RSAAlgorithm.SHA384), + "RS512": RSAAlgorithm(RSAAlgorithm.SHA512), + "ES256": ECAlgorithm(ECAlgorithm.SHA256), + "ES256K": ECAlgorithm(ECAlgorithm.SHA256), + "ES384": ECAlgorithm(ECAlgorithm.SHA384), + "ES521": ECAlgorithm(ECAlgorithm.SHA512), + "ES512": ECAlgorithm(ECAlgorithm.SHA512), # Backward compat for #219 fix + "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256), + "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384), + "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512), + "EdDSA": OKPAlgorithm(), + } return default_algorithms @@ -153,15 +150,14 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes: raise NotImplementedError if ( - has_crypto - and isinstance(hash_alg, type) - and issubclass(hash_alg, hashes.HashAlgorithm) + not has_crypto + or not isinstance(hash_alg, type) + or not issubclass(hash_alg, hashes.HashAlgorithm) ): - digest = hashes.Hash(hash_alg(), backend=default_backend()) - digest.update(bytestr) - return bytes(digest.finalize()) - else: return bytes(hash_alg(bytestr).digest()) + digest = hashes.Hash(hash_alg(), backend=default_backend()) + digest.update(bytestr) + return bytes(digest.finalize()) @abstractmethod def prepare_key(self, key: Any) -> Any: @@ -282,10 +278,7 @@ def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str: "kty": "oct", } - if as_dict: - return jwk - else: - return json.dumps(jwk) + return jwk if as_dict else json.dumps(jwk) @staticmethod def from_jwk(jwk: str | JWKDict) -> bytes: @@ -296,8 +289,8 @@ def from_jwk(jwk: str | JWKDict) -> bytes: obj = jwk else: raise ValueError - except ValueError: - raise InvalidKeyError("Key is not valid JSON") + except ValueError as e: + raise InvalidKeyError("Key is not valid JSON") from e if obj.get("kty") != "oct": raise InvalidKeyError("Not an HMAC key") @@ -345,8 +338,10 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: except ValueError: try: return cast(RSAPublicKey, load_pem_public_key(key_bytes)) - except (ValueError, UnsupportedAlgorithm): - raise InvalidKeyError("Could not parse the provided public key.") + except (ValueError, UnsupportedAlgorithm) as e: + raise InvalidKeyError( + "Could not parse the provided public key." + ) from e @overload @staticmethod @@ -394,10 +389,7 @@ def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str: else: raise InvalidKeyError("Not a public or private key") - if as_dict: - return obj - else: - return json.dumps(obj) + return obj if as_dict else json.dumps(obj) @staticmethod def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: @@ -408,8 +400,8 @@ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: obj = jwk else: raise ValueError - except ValueError: - raise InvalidKeyError("Key is not valid JSON") + except ValueError as e: + raise InvalidKeyError("Key is not valid JSON") from e if obj.get("kty") != "RSA": raise InvalidKeyError("Not an RSA key") @@ -494,12 +486,15 @@ class ECAlgorithm(Algorithm): def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: + def prepare_key(self, key: AllowedECKeys | str | bytes | dict) -> AllowedECKeys: if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): return key + if isinstance(key, dict): + return self._load_jwk(key) + if not isinstance(key, (bytes, str)): - raise TypeError("Expecting a PEM-formatted key.") + raise TypeError("Expecting a PEM-formatted key or JWK.") key_bytes = force_bytes(key) @@ -524,6 +519,38 @@ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: return crypto_key + def _load_jwk(self, jwk: dict) -> EllipticCurvePublicKey: + if jwk.get("kty") != "EC": + raise InvalidKeyError("Not an EC key") + + curve = self._get_curve(jwk["crv"]) + x = self._base64url_decode(jwk["x"]) + y = self._base64url_decode(jwk["y"]) + + public_numbers = ec.EllipticCurvePublicNumbers( + x=int.from_bytes(x, byteorder="big"), + y=int.from_bytes(y, byteorder="big"), + curve=curve, + ) + + return public_numbers.public_key() + + def _get_curve(self, crv: str) -> ec.EllipticCurve: + if crv == "P-256": + return ec.SECP256R1() + elif crv == "P-384": + return ec.SECP384R1() + elif crv == "P-521": + return ec.SECP521R1() + elif crv == "secp256k1": + return ec.SECP256K1() + else: + raise InvalidKeyError(f"Invalid curve: {crv}") + + def _base64url_decode(self, input: str) -> bytes: + input += "=" * (4 - len(input) % 4) + return base64.urlsafe_b64decode(input) + def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: der_sig = key.sign(msg, ECDSA(self.hash_alg())) @@ -590,10 +617,7 @@ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str: key_obj.private_numbers().private_value ).decode() - if as_dict: - return obj - else: - return json.dumps(obj) + return obj if as_dict else json.dumps(obj) @staticmethod def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: @@ -604,8 +628,8 @@ def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: obj = jwk else: raise ValueError - except ValueError: - raise InvalidKeyError("Key is not valid JSON") + except ValueError as e: + raise InvalidKeyError("Key is not valid JSON") from e if obj.get("kty") != "EC": raise InvalidKeyError("Not an Elliptic curve key") @@ -712,7 +736,7 @@ def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys: key = load_pem_public_key(key_bytes) # type: ignore[assignment] elif "-----BEGIN PRIVATE" in key_str: key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] - elif key_str[0:4] == "ssh-": + elif key_str[:4] == "ssh-": key = load_ssh_public_key(key_bytes) # type: ignore[assignment] # Explicit check the key to prevent confusing errors from cryptography @@ -792,10 +816,7 @@ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str: "crv": crv, } - if as_dict: - return obj - else: - return json.dumps(obj) + return obj if as_dict else json.dumps(obj) if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): d = key.private_bytes( @@ -817,11 +838,7 @@ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str: "crv": crv, } - if as_dict: - return obj - else: - return json.dumps(obj) - + return obj if as_dict else json.dumps(obj) raise InvalidKeyError("Not a public or private key") @staticmethod @@ -833,14 +850,14 @@ def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys: obj = jwk else: raise ValueError - except ValueError: - raise InvalidKeyError("Key is not valid JSON") + except ValueError as e: + raise InvalidKeyError("Key is not valid JSON") from e if obj.get("kty") != "OKP": raise InvalidKeyError("Not an Octet Key Pair") curve = obj.get("crv") - if curve != "Ed25519" and curve != "Ed448": + if curve not in ["Ed25519", "Ed448"]: raise InvalidKeyError(f"Invalid curve: {curve}") if "x" not in obj: