Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added verifying detached payload JWS with JWK #997

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ def find_version(*file_paths) -> str:
string inside.
"""
version_file = read(*file_paths)
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
if version_match:
return version_match.group(1)
if version_match := re.search(
r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M
):
return version_match[1]
raise RuntimeError("Unable to find version string.")


Expand Down
139 changes: 78 additions & 61 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import base64
import hashlib
import hmac
import json
Expand All @@ -24,7 +25,7 @@
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import ec, padding
from cryptography.hazmat.primitives.asymmetric.ec import (
ECDSA,
SECP256K1,
Expand Down Expand Up @@ -114,24 +115,20 @@ def get_default_algorithms() -> dict[str, Algorithm]:
}

if has_crypto:
default_algorithms.update(
{
"RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
"RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
"RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
"ES256": ECAlgorithm(ECAlgorithm.SHA256),
"ES256K": ECAlgorithm(ECAlgorithm.SHA256),
"ES384": ECAlgorithm(ECAlgorithm.SHA384),
"ES521": ECAlgorithm(ECAlgorithm.SHA512),
"ES512": ECAlgorithm(
ECAlgorithm.SHA512
), # Backward compat for #219 fix
"PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
"PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
"PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
"EdDSA": OKPAlgorithm(),
}
)
default_algorithms |= {
"RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
"RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
"RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
"ES256": ECAlgorithm(ECAlgorithm.SHA256),
"ES256K": ECAlgorithm(ECAlgorithm.SHA256),
"ES384": ECAlgorithm(ECAlgorithm.SHA384),
"ES521": ECAlgorithm(ECAlgorithm.SHA512),
"ES512": ECAlgorithm(ECAlgorithm.SHA512), # Backward compat for #219 fix
"PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
"PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
"PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
"EdDSA": OKPAlgorithm(),
}

return default_algorithms

Expand All @@ -153,15 +150,14 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
raise NotImplementedError

if (
has_crypto
and isinstance(hash_alg, type)
and issubclass(hash_alg, hashes.HashAlgorithm)
not has_crypto
or not isinstance(hash_alg, type)
or not issubclass(hash_alg, hashes.HashAlgorithm)
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return bytes(digest.finalize())
else:
return bytes(hash_alg(bytestr).digest())
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return bytes(digest.finalize())

@abstractmethod
def prepare_key(self, key: Any) -> Any:
Expand Down Expand Up @@ -282,10 +278,7 @@ def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
"kty": "oct",
}

if as_dict:
return jwk
else:
return json.dumps(jwk)
return jwk if as_dict else json.dumps(jwk)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> bytes:
Expand All @@ -296,8 +289,8 @@ def from_jwk(jwk: str | JWKDict) -> bytes:
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
except ValueError as e:
raise InvalidKeyError("Key is not valid JSON") from e

if obj.get("kty") != "oct":
raise InvalidKeyError("Not an HMAC key")
Expand Down Expand Up @@ -345,8 +338,10 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
except ValueError:
try:
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
except (ValueError, UnsupportedAlgorithm):
raise InvalidKeyError("Could not parse the provided public key.")
except (ValueError, UnsupportedAlgorithm) as e:
raise InvalidKeyError(
"Could not parse the provided public key."
) from e

@overload
@staticmethod
Expand Down Expand Up @@ -394,10 +389,7 @@ def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
else:
raise InvalidKeyError("Not a public or private key")

if as_dict:
return obj
else:
return json.dumps(obj)
return obj if as_dict else json.dumps(obj)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
Expand All @@ -408,8 +400,8 @@ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
except ValueError as e:
raise InvalidKeyError("Key is not valid JSON") from e

if obj.get("kty") != "RSA":
raise InvalidKeyError("Not an RSA key")
Expand Down Expand Up @@ -494,12 +486,15 @@ class ECAlgorithm(Algorithm):
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg

def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
def prepare_key(self, key: AllowedECKeys | str | bytes | dict) -> AllowedECKeys:
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
return key

if isinstance(key, dict):
return self._load_jwk(key)

if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")
raise TypeError("Expecting a PEM-formatted key or JWK.")

key_bytes = force_bytes(key)

Expand All @@ -524,6 +519,38 @@ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:

return crypto_key

def _load_jwk(self, jwk: dict) -> EllipticCurvePublicKey:
if jwk.get("kty") != "EC":
raise InvalidKeyError("Not an EC key")

curve = self._get_curve(jwk["crv"])
x = self._base64url_decode(jwk["x"])
y = self._base64url_decode(jwk["y"])

public_numbers = ec.EllipticCurvePublicNumbers(
x=int.from_bytes(x, byteorder="big"),
y=int.from_bytes(y, byteorder="big"),
curve=curve,
)

return public_numbers.public_key()

def _get_curve(self, crv: str) -> ec.EllipticCurve:
if crv == "P-256":
return ec.SECP256R1()
elif crv == "P-384":
return ec.SECP384R1()
elif crv == "P-521":
return ec.SECP521R1()
elif crv == "secp256k1":
return ec.SECP256K1()
else:
raise InvalidKeyError(f"Invalid curve: {crv}")

def _base64url_decode(self, input: str) -> bytes:
input += "=" * (4 - len(input) % 4)
return base64.urlsafe_b64decode(input)

def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
der_sig = key.sign(msg, ECDSA(self.hash_alg()))

Expand Down Expand Up @@ -590,10 +617,7 @@ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
key_obj.private_numbers().private_value
).decode()

if as_dict:
return obj
else:
return json.dumps(obj)
return obj if as_dict else json.dumps(obj)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
Expand All @@ -604,8 +628,8 @@ def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
except ValueError as e:
raise InvalidKeyError("Key is not valid JSON") from e

if obj.get("kty") != "EC":
raise InvalidKeyError("Not an Elliptic curve key")
Expand Down Expand Up @@ -712,7 +736,7 @@ def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
key = load_pem_public_key(key_bytes) # type: ignore[assignment]
elif "-----BEGIN PRIVATE" in key_str:
key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
elif key_str[0:4] == "ssh-":
elif key_str[:4] == "ssh-":
key = load_ssh_public_key(key_bytes) # type: ignore[assignment]

# Explicit check the key to prevent confusing errors from cryptography
Expand Down Expand Up @@ -792,10 +816,7 @@ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
"crv": crv,
}

if as_dict:
return obj
else:
return json.dumps(obj)
return obj if as_dict else json.dumps(obj)

if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
d = key.private_bytes(
Expand All @@ -817,11 +838,7 @@ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
"crv": crv,
}

if as_dict:
return obj
else:
return json.dumps(obj)

return obj if as_dict else json.dumps(obj)
raise InvalidKeyError("Not a public or private key")

@staticmethod
Expand All @@ -833,14 +850,14 @@ def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
except ValueError as e:
raise InvalidKeyError("Key is not valid JSON") from e

if obj.get("kty") != "OKP":
raise InvalidKeyError("Not an Octet Key Pair")

curve = obj.get("crv")
if curve != "Ed25519" and curve != "Ed448":
if curve not in ["Ed25519", "Ed448"]:
raise InvalidKeyError(f"Invalid curve: {curve}")

if "x" not in obj:
Expand Down
2 changes: 1 addition & 1 deletion jwt/api_jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self, keys: list[JWKDict]) -> None:
# skip unusable keys
continue

if len(self.keys) == 0:
if not self.keys:
raise PyJWKSetError(
"The JWK Set did not contain any usable keys. Perhaps 'cryptography' is not installed?"
)
Expand Down
15 changes: 4 additions & 11 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ def encode(
is_payload_detached: bool = False,
sort_headers: bool = True,
) -> str:
segments = []

# declare a new var to narrow the type for type checkers
if algorithm is None:
if isinstance(key, PyJWK):
Expand All @@ -125,8 +123,7 @@ def encode(

# Prefer headers values if present to function parameters.
if headers:
headers_alg = headers.get("alg")
if headers_alg:
if headers_alg := headers.get("alg"):
algorithm_ = headers["alg"]

headers_b64 = headers.get("b64")
Expand All @@ -138,7 +135,7 @@ def encode(

if headers:
self._validate_headers(headers)
header.update(headers)
header |= headers

if not header["typ"]:
del header["typ"]
Expand All @@ -153,12 +150,8 @@ def encode(
header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
).encode()

segments.append(base64url_encode(json_header))

if is_payload_detached:
msg_payload = payload
else:
msg_payload = base64url_encode(payload)
segments = [base64url_encode(json_header)]
msg_payload = payload if is_payload_detached else base64url_encode(payload)
segments.append(msg_payload)

# Segments
Expand Down
29 changes: 16 additions & 13 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _decode_payload(self, decoded: dict[str, Any]) -> Any:
try:
payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError(f"Invalid payload string: {e}")
raise DecodeError(f"Invalid payload string: {e}") from e
if not isinstance(payload, dict):
raise DecodeError("Invalid payload string: must be a json object")
return payload
Expand Down Expand Up @@ -268,8 +268,10 @@ def _validate_iat(
) -> None:
try:
iat = int(payload["iat"])
except ValueError:
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
except ValueError as e:
raise InvalidIssuedAtError(
"Issued At claim (iat) must be an integer."
) from e
if iat > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (iat)")

Expand All @@ -281,8 +283,8 @@ def _validate_nbf(
) -> None:
try:
nbf = int(payload["nbf"])
except ValueError:
raise DecodeError("Not Before claim (nbf) must be an integer.")
except ValueError as e:
raise DecodeError("Not Before claim (nbf) must be an integer.") from e

if nbf > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (nbf)")
Expand All @@ -295,8 +297,8 @@ def _validate_exp(
) -> None:
try:
exp = int(payload["exp"])
except ValueError:
raise DecodeError("Expiration Time claim (exp) must be an integer.")
except ValueError as e:
raise DecodeError("Expiration Time claim (exp) must be an integer.") from e

if exp <= (now - leeway):
raise ExpiredSignatureError("Signature has expired")
Expand Down Expand Up @@ -358,12 +360,13 @@ def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
if "iss" not in payload:
raise MissingRequiredClaimError("iss")

if isinstance(issuer, Sequence):
if payload["iss"] not in issuer:
raise InvalidIssuerError("Invalid issuer")
else:
if payload["iss"] != issuer:
raise InvalidIssuerError("Invalid issuer")
if (
isinstance(issuer, Sequence)
and payload["iss"] not in issuer
or not isinstance(issuer, Sequence)
and payload["iss"] != issuer
):
raise InvalidIssuerError("Invalid issuer")


_jwt_global_obj = PyJWT()
Expand Down
Loading