Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing Not checking for expired JWT token #238

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
# See License in the project root for license information.
# ------------------------------------------------------------------------------

from typing import Any, Dict
import re
import time
from typing import Any, Dict, Optional

import jwt

from ..headers_collection import HeadersCollection
from ..request_information import RequestInformation
Expand All @@ -13,18 +17,19 @@


class BaseBearerTokenAuthenticationProvider(AuthenticationProvider):
"""Provides a base class for implementing AuthenticationProvider for Bearer token scheme.
"""
"""Provides a base class for implementing AuthenticationProvider for Bearer token scheme."""

AUTHORIZATION_HEADER = "Authorization"
CLAIMS_KEY = "claims"
AUTHORIZATION_PREFIX = "Bearer"

def __init__(self, access_token_provider: AccessTokenProvider) -> None:
self.access_token_provider = access_token_provider

async def authenticate_request(
self,
request: RequestInformation,
additional_authentication_context: Dict[str, Any] = {}
additional_authentication_context: Dict[str, Any] = None,
) -> None:
"""Authenticates the provided RequestInformation instance using the provided
authorization token
Expand All @@ -34,21 +39,90 @@ async def authenticate_request(
"""
if not request:
raise Exception("Request cannot be null")
if all(
[
additional_authentication_context, self.CLAIMS_KEY
in additional_authentication_context,
request.headers.contains(self.AUTHORIZATION_HEADER)
]
):
request.headers.remove(self.AUTHORIZATION_HEADER)

if not request.request_headers:
request.headers = HeadersCollection()

if additional_authentication_context:
self._check_for_claims_key(request, additional_authentication_context)

self._remove_expired_token(request)

if not request.headers.contains(self.AUTHORIZATION_HEADER):
token = await self.access_token_provider.get_authorization_token(
request.url, additional_authentication_context
)
if token:
request.headers.add(f'{self.AUTHORIZATION_HEADER}', f'Bearer {token}')
request.headers.add(
f"{self.AUTHORIZATION_HEADER}",
f"{self.AUTHORIZATION_PREFIX} {token}",
)
return

def _check_for_claims_key(
self,
request: RequestInformation,
additional_authentication_context: Dict[str, Any],
) -> None:
"""
Checks if the claims key is in the additional authentication context and if the
authorization header is in the request headers. If both conditions are met, it removes
the authorization header from the request headers.

Args:
request (RequestInformation): The request information object.
additional_authentication_context (Dict[str, Any]): Additional context for authentication.
"""

if all(
[
self.CLAIMS_KEY in additional_authentication_context,
request.headers.contains(self.AUTHORIZATION_HEADER),
]
):
request.headers.remove(self.AUTHORIZATION_HEADER)

def _remove_expired_token(self, request: RequestInformation) -> None:
"""
Removes expired tokens from the request headers.

Args:
request (RequestInformation): The request information object.
"""

tokens = request.headers.get(self.AUTHORIZATION_HEADER)

if tokens:
for _token in tokens:
matchs = re.match(rf"{self.AUTHORIZATION_PREFIX} (.*)", _token)
if matchs:
token = matchs.group(0).split(" ")[1]
if self.is_token_expired(token):
request.headers.remove(self.AUTHORIZATION_HEADER)

@staticmethod
def is_token_expired(token: str) -> bool:
"""
Checks if the given token is expired.

Args:
token (str): The token to check.

Returns:
bool: True if the token is expired, False otherwise.
"""

try:
payload: Dict[str, Any] = jwt.decode(
token, options={"verify_signature": False}
)

# Get the expiration time (exp), which is a Unix timestamp
exp: Optional[int] = payload.get("exp", None)

# Check if the current time is past the token's expiration
if exp is not None and time.time() >= exp:
return True
return False
except jwt.DecodeError:
return True # If we can't decode the token, consider it as expired