Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented JWS verifying using JWK #999

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 78 additions & 61 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import base64
import hashlib
import hmac
import json
Expand All @@ -24,7 +25,7 @@
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import ec, padding
from cryptography.hazmat.primitives.asymmetric.ec import (
ECDSA,
SECP256K1,
Expand Down Expand Up @@ -114,24 +115,20 @@ def get_default_algorithms() -> dict[str, Algorithm]:
}

if has_crypto:
default_algorithms.update(
{
"RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
"RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
"RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
"ES256": ECAlgorithm(ECAlgorithm.SHA256),
"ES256K": ECAlgorithm(ECAlgorithm.SHA256),
"ES384": ECAlgorithm(ECAlgorithm.SHA384),
"ES521": ECAlgorithm(ECAlgorithm.SHA512),
"ES512": ECAlgorithm(
ECAlgorithm.SHA512
), # Backward compat for #219 fix
"PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
"PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
"PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
"EdDSA": OKPAlgorithm(),
}
)
default_algorithms |= {
"RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
"RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
"RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
"ES256": ECAlgorithm(ECAlgorithm.SHA256),
"ES256K": ECAlgorithm(ECAlgorithm.SHA256),
"ES384": ECAlgorithm(ECAlgorithm.SHA384),
"ES521": ECAlgorithm(ECAlgorithm.SHA512),
"ES512": ECAlgorithm(ECAlgorithm.SHA512), # Backward compat for #219 fix
"PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
"PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
"PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
"EdDSA": OKPAlgorithm(),
}

return default_algorithms

Expand All @@ -153,15 +150,14 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
raise NotImplementedError

if (
has_crypto
and isinstance(hash_alg, type)
and issubclass(hash_alg, hashes.HashAlgorithm)
not has_crypto
or not isinstance(hash_alg, type)
or not issubclass(hash_alg, hashes.HashAlgorithm)
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return bytes(digest.finalize())
else:
return bytes(hash_alg(bytestr).digest())
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return bytes(digest.finalize())

@abstractmethod
def prepare_key(self, key: Any) -> Any:
Expand Down Expand Up @@ -282,10 +278,7 @@ def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
"kty": "oct",
}

if as_dict:
return jwk
else:
return json.dumps(jwk)
return jwk if as_dict else json.dumps(jwk)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> bytes:
Expand All @@ -296,8 +289,8 @@ def from_jwk(jwk: str | JWKDict) -> bytes:
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
except ValueError as e:
raise InvalidKeyError("Key is not valid JSON") from e

if obj.get("kty") != "oct":
raise InvalidKeyError("Not an HMAC key")
Expand Down Expand Up @@ -345,8 +338,10 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
except ValueError:
try:
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
except (ValueError, UnsupportedAlgorithm):
raise InvalidKeyError("Could not parse the provided public key.")
except (ValueError, UnsupportedAlgorithm) as e:
raise InvalidKeyError(
"Could not parse the provided public key."
) from e

@overload
@staticmethod
Expand Down Expand Up @@ -394,10 +389,7 @@ def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
else:
raise InvalidKeyError("Not a public or private key")

if as_dict:
return obj
else:
return json.dumps(obj)
return obj if as_dict else json.dumps(obj)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
Expand All @@ -408,8 +400,8 @@ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
except ValueError as e:
raise InvalidKeyError("Key is not valid JSON") from e

if obj.get("kty") != "RSA":
raise InvalidKeyError("Not an RSA key")
Expand Down Expand Up @@ -494,12 +486,15 @@ class ECAlgorithm(Algorithm):
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg

def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
def prepare_key(self, key: AllowedECKeys | str | bytes | dict) -> AllowedECKeys:
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
return key

if isinstance(key, dict):
return self._load_jwk(key)

if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")
raise TypeError("Expecting a PEM-formatted key or JWK.")

key_bytes = force_bytes(key)

Expand All @@ -524,6 +519,38 @@ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:

return crypto_key

def _load_jwk(self, jwk: dict) -> EllipticCurvePublicKey:
if jwk.get("kty") != "EC":
raise InvalidKeyError("Not an EC key")

curve = self._get_curve(jwk["crv"])
x = self._base64url_decode(jwk["x"])
y = self._base64url_decode(jwk["y"])

public_numbers = ec.EllipticCurvePublicNumbers(
x=int.from_bytes(x, byteorder="big"),
y=int.from_bytes(y, byteorder="big"),
curve=curve,
)

return public_numbers.public_key()

def _get_curve(self, crv: str) -> ec.EllipticCurve:
if crv == "P-256":
return ec.SECP256R1()
elif crv == "P-384":
return ec.SECP384R1()
elif crv == "P-521":
return ec.SECP521R1()
elif crv == "secp256k1":
return ec.SECP256K1()
else:
raise InvalidKeyError(f"Invalid curve: {crv}")

def _base64url_decode(self, input: str) -> bytes:
input += "=" * (4 - len(input) % 4)
return base64.urlsafe_b64decode(input)

def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
der_sig = key.sign(msg, ECDSA(self.hash_alg()))

Expand Down Expand Up @@ -590,10 +617,7 @@ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
key_obj.private_numbers().private_value
).decode()

if as_dict:
return obj
else:
return json.dumps(obj)
return obj if as_dict else json.dumps(obj)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
Expand All @@ -604,8 +628,8 @@ def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
except ValueError as e:
raise InvalidKeyError("Key is not valid JSON") from e

if obj.get("kty") != "EC":
raise InvalidKeyError("Not an Elliptic curve key")
Expand Down Expand Up @@ -712,7 +736,7 @@ def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
key = load_pem_public_key(key_bytes) # type: ignore[assignment]
elif "-----BEGIN PRIVATE" in key_str:
key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
elif key_str[0:4] == "ssh-":
elif key_str[:4] == "ssh-":
key = load_ssh_public_key(key_bytes) # type: ignore[assignment]

# Explicit check the key to prevent confusing errors from cryptography
Expand Down Expand Up @@ -792,10 +816,7 @@ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
"crv": crv,
}

if as_dict:
return obj
else:
return json.dumps(obj)
return obj if as_dict else json.dumps(obj)

if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
d = key.private_bytes(
Expand All @@ -817,11 +838,7 @@ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
"crv": crv,
}

if as_dict:
return obj
else:
return json.dumps(obj)

return obj if as_dict else json.dumps(obj)
raise InvalidKeyError("Not a public or private key")

@staticmethod
Expand All @@ -833,14 +850,14 @@ def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
except ValueError as e:
raise InvalidKeyError("Key is not valid JSON") from e

if obj.get("kty") != "OKP":
raise InvalidKeyError("Not an Octet Key Pair")

curve = obj.get("crv")
if curve != "Ed25519" and curve != "Ed448":
if curve not in ["Ed25519", "Ed448"]:
raise InvalidKeyError(f"Invalid curve: {curve}")

if "x" not in obj:
Expand Down
Loading