From e94c99eacbdc63d75051a81c8db6549efeb45e1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Brunner?= Date: Thu, 19 Sep 2024 10:22:08 +0200 Subject: [PATCH] OpenID Connect: Add hook to be able to customize role creation --- doc/integrator/authentication_oidc.rst | 31 +++++ geoportal/c2cgeoportal_geoportal/__init__.py | 32 +---- geoportal/c2cgeoportal_geoportal/lib/oidc.py | 124 ++++++++++++++---- .../c2cgeoportal_geoportal/views/login.py | 21 +-- geoportal/tests/functional/test_oidc.py | 12 +- 5 files changed, 148 insertions(+), 72 deletions(-) diff --git a/doc/integrator/authentication_oidc.rst b/doc/integrator/authentication_oidc.rst index 593d64d809..7dcfb64115 100644 --- a/doc/integrator/authentication_oidc.rst +++ b/doc/integrator/authentication_oidc.rst @@ -88,3 +88,34 @@ Other options ``scopes``: The list of scopes to request, default is [``openid``, ``profile``, ``email``]. ``query_user_info``: If ``true``, the user info will be requested instead if using the ``id_token``, default is false. + +~~~~~ +Hooks +~~~~~ + +If you want to redefine the user creation process, you can use the hooks ``get_remember_from_user_info`` +and ``get_user_from_remember``. + +``get_remember_from_user_info``: This hook is called during the user is authentication. +The argument are the pyramid ``request``, the received ``user_info``, and the ``remember_object`` dictionary +to be filled and will be stored in the cookie. + +``get_user_from_remember``: This hook is called during the user is certification. +The argument are the pyramid ``request``, the received ``remember_object``, and the ``create_user`` boolean. +The return value is the user object ``User`` or ``DynamicUsed``. + +Full signatures: + +.. code:: python + + def get_remember_from_user_info(request: Request, user_info: Dict[str, Any], remember_object: OidcRememberObject) -> None: + + def get_user_from_remember(request: Request, remember_object: OidcRememberObject, create_user: bool) -> Union[User, DynamicUsed]: + +Configure the hooks in the project initialization: + +.. code:: python + + def includeme(config): + config.add_request_method(get_remember_from_user_info, name="get_remember_from_user_info") + config.add_request_method(get_user_from_remember, name="get_user_from_remember") diff --git a/geoportal/c2cgeoportal_geoportal/__init__.py b/geoportal/c2cgeoportal_geoportal/__init__.py index 7b684b9092..45f5df7240 100644 --- a/geoportal/c2cgeoportal_geoportal/__init__.py +++ b/geoportal/c2cgeoportal_geoportal/__init__.py @@ -49,7 +49,6 @@ import sqlalchemy.orm import zope.event.classhandler from c2cgeoform import translator -from c2cwsgiutils.broadcast import decorator from c2cwsgiutils.health_check import HealthCheck from c2cwsgiutils.prometheus import MemoryMapCollector from deform import Form @@ -57,15 +56,15 @@ from papyrus.renderers import GeoJSON from prometheus_client.core import REGISTRY from pyramid.config import Configurator -from pyramid.httpexceptions import HTTPBadRequest, HTTPException +from pyramid.httpexceptions import HTTPException from pyramid.path import AssetResolver from pyramid_mako import add_mako_renderer -from sqlalchemy.orm import Session, joinedload +from sqlalchemy.orm import joinedload import c2cgeoportal_commons.models import c2cgeoportal_geoportal.views from c2cgeoportal_commons.models import InvalidateCacheEvent -from c2cgeoportal_geoportal.lib import C2CPregenerator, caching, check_collector, checker +from c2cgeoportal_geoportal.lib import C2CPregenerator, caching, check_collector, checker, oidc from c2cgeoportal_geoportal.lib.cacheversion import version_cache_buster from c2cgeoportal_geoportal.lib.common_headers import Cache, set_common_headers from c2cgeoportal_geoportal.lib.i18n import available_locale_names @@ -327,7 +326,6 @@ def get_user_from_request( """ from c2cgeoportal_commons.models import DBSession # pylint: disable=import-outside-toplevel from c2cgeoportal_commons.models.static import User # pylint: disable=import-outside-toplevel - from c2cgeoportal_geoportal.lib import oidc # pylint: disable=import-outside-toplevel assert DBSession is not None @@ -357,28 +355,7 @@ def get_user_from_request( ) user_info = oidc.OidcRemember(request).remember(token_response, request.host) - if openid_connect_config.get("provide_roles", False) is True: - from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel - Role, - ) - - request.user_ = oidc.DynamicUser( - username=user_info["username"], - email=user_info["email"], - settings_role=( - DBSession.query(Role).filter_by(name=user_info["settings_role"]).first() - if user_info.get("settings_role") is not None - else None - ), - roles=[ - DBSession.query(Role).filter_by(name=role).one() - for role in user_info.get("roles", []) - ], - ) - else: - request.user_ = DBSession.query(User).filter_by(email=user_info["email"]).first() - for user in DBSession.query(User).all(): - _LOG.error(user.username) + request.user_ = request.get_user_from_reminder(user_info) else: # We know we will need the role object of the # user so we use joined loading @@ -527,6 +504,7 @@ def includeme(config: pyramid.config.Configurator) -> None: config.include("pyramid_mako") config.include("c2cwsgiutils.pyramid.includeme") + config.include(oidc.includeme) health_check = HealthCheck(config) config.registry["health_check"] = health_check diff --git a/geoportal/c2cgeoportal_geoportal/lib/oidc.py b/geoportal/c2cgeoportal_geoportal/lib/oidc.py index 955eb0b3db..9787b20c3a 100644 --- a/geoportal/c2cgeoportal_geoportal/lib/oidc.py +++ b/geoportal/c2cgeoportal_geoportal/lib/oidc.py @@ -28,7 +28,7 @@ import datetime import json import logging -from typing import NamedTuple, TypedDict +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, TypedDict, Union import pyramid.request import pyramid.response @@ -37,9 +37,11 @@ from pyramid.httpexceptions import HTTPBadRequest, HTTPInternalServerError, HTTPUnauthorized from pyramid.security import remember -from c2cgeoportal_commons.models import main from c2cgeoportal_geoportal.lib.caching import get_region +if TYPE_CHECKING: + from c2cgeoportal_commons.models import main, static + _LOG = logging.getLogger(__name__) _CACHE_REGION_OBJ = get_region("obj") @@ -52,8 +54,8 @@ class DynamicUser(NamedTuple): username: str email: str - settings_role: main.Role | None - roles: list[main.Role] + settings_role: Optional["main.Role"] + roles: list["main.Role"] @_CACHE_REGION_OBJ.cache_on_arguments() @@ -69,7 +71,6 @@ def get_oidc_client(request: pyramid.request.Request, host: str) -> simple_openi if openid_connect.get("enabled", False) is not True: raise HTTPBadRequest("OpenID Connect not enabled") - _LOG.info(openid_connect) return simple_openid_connect.client.OpenidClient.from_issuer_url( url=openid_connect["url"], authentication_redirect_uri=request.route_url("oidc_callback"), @@ -94,6 +95,91 @@ class OidcRememberObject(TypedDict): roles: list[str] +def get_remember_from_user_info( + request: pyramid.request.Request, user_info: dict[str, Any], remember_object: OidcRememberObject +) -> None: + """ + Fill the remember object from the user info. + + The remember object will be stored in a cookie to remember the user. + + :param user_info: The user info from the ID token or from the user info view according to the `query_user_info` configuration. + :param remember_object: The object to fill, by default with the `username`, `email`, `settings_role` and `roles`, + the corresponding field from `user_info` can be configured in `user_info_fields`. + :param settings: The OpenID Connect configuration. + """ + settings_fields = ( + request.registry.settings.get("authentication", {}) + .get("openid_connect", {}) + .get("user_info_fields", {}) + ) + + for field_, default_field in ( + ("username", "name"), + ("email", "email"), + ("settings_role", None), + ("roles", None), + ): + user_info_field = settings_fields.get(field_, default_field) + if user_info_field is not None: + if user_info_field not in user_info: + _LOG.error( + "Field '%s' not found in user info, available: %s.", + user_info_field, + ", ".join(user_info.keys()), + ) + raise HTTPInternalServerError(f"Field '{user_info_field}' not found in user info.") + remember_object[field_] = user_info[user_info_field] # type: ignore[literal-required] + + +def get_user_from_remember( + request: pyramid.request.Request, remember_object: OidcRememberObject, create_user: bool = False +) -> Union["static.User", DynamicUser] | None: + """ + Create a user from the remember object filled from `get_remember_from_user_info`. + + :param remember_object: The object to fill, by default with the `username`, `email`, `settings_role` and `roles`. + :param settings: The OpenID Connect configuration. + :param create_user: If the user should be created if it does not exist. + """ + from c2cgeoportal_commons import models # pylint: disable=import-outside-toplevel + from c2cgeoportal_commons.models import main, static # pylint: disable=import-outside-toplevel + + assert models.DBSession is not None + + user: static.User | DynamicUser | None + username = remember_object["username"] + assert username is not None + email = remember_object["email"] + assert email is not None + + provide_roles = ( + request.registry.settings.get("authentication", {}) + .get("openid_connect", {}) + .get("provide_roles", False) + ) + if provide_roles is False: + user = models.DBSession.query(static.User).filter_by(email=email).one_or_none() + if user is None and create_user is True: + user = static.User(username=username, email=email) + models.DBSession.add(user) + else: + user = DynamicUser( + username=username, + email=email, + settings_role=( + models.DBSession.query(main.Role).filter_by(name=remember_object["settings_role"]).first() + if remember_object.get("settings_role") is not None + else None + ), + roles=[ + models.DBSession.query(main.Role).filter_by(name=role).one() + for role in remember_object.get("roles", []) + ], + ) + return user + + class OidcRemember: """ Build the abject that we want to remember in the cookie. @@ -148,7 +234,6 @@ def remember( "settings_role": None, "roles": [], } - settings_fields = openid_connect.get("user_info_fields", {}) client = get_oidc_client(self.request, self.request.host) if openid_connect.get("query_user_info", False) is True: @@ -172,24 +257,15 @@ def remember( ), ) - for field_, default_field in ( - ("username", "name"), - ("email", "email"), - ("settings_role", None), - ("roles", None), - ): - user_info_field = settings_fields.get(field_, default_field) - if user_info_field is not None: - user_info_dict = user_info.dict() - if user_info_field not in user_info_dict: - _LOG.error( - "Field '%s' not found in user info, available: %s.", - user_info_field, - ", ".join(user_info_dict.keys()), - ) - raise HTTPInternalServerError(f"Field '{user_info_field}' not found in user info.") - remember_object[field_] = user_info_dict[user_info_field] # type: ignore[literal-required] - + self.request.get_remember_from_user_info(user_info.dict(), remember_object) self.request.response.headers.extend(remember(self.request, json.dumps(remember_object))) return remember_object + + +def includeme(config: pyramid.config.Configurator) -> None: + """ + Pyramid includeme function. + """ + config.add_request_method(get_remember_from_user_info, name="get_remember_from_user_info") + config.add_request_method(get_user_from_remember, name="get_user_from_remember") diff --git a/geoportal/c2cgeoportal_geoportal/views/login.py b/geoportal/c2cgeoportal_geoportal/views/login.py index 49645bfd4d..558fa45e47 100644 --- a/geoportal/c2cgeoportal_geoportal/views/login.py +++ b/geoportal/c2cgeoportal_geoportal/views/login.py @@ -644,26 +644,7 @@ def oidc_callback(self) -> pyramid.response.Response: remember_object = oidc.OidcRemember(self.request).remember(token_response, self.request.host) - user: static.User | oidc.DynamicUser | None - if self.authentication_settings.get("openid_connect", {}).get("provide_roles", False) is False: - user = models.DBSession.query(static.User).filter_by(email=remember_object["email"]).one_or_none() - if user is None: - user = static.User(username=remember_object["username"], email=remember_object["email"]) - models.DBSession.add(user) - else: - user = oidc.DynamicUser( - username=remember_object["username"], - email=remember_object["email"], - settings_role=( - models.DBSession.query(main.Role).filter_by(name=remember_object["settings_role"]).first() - if remember_object.get("settings_role") is not None - else None - ), - roles=[ - models.DBSession.query(main.Role).filter_by(name=role).one() - for role in remember_object.get("roles", []) - ], - ) + user: static.User | oidc.DynamicUser | None = self.request.get_user_from_remember(remember_object) assert user is not None self.request.user_ = user diff --git a/geoportal/tests/functional/test_oidc.py b/geoportal/tests/functional/test_oidc.py index accfb1291f..21e7e8d688 100644 --- a/geoportal/tests/functional/test_oidc.py +++ b/geoportal/tests/functional/test_oidc.py @@ -1,12 +1,12 @@ import base64 import re +import types import urllib.parse from http.client import responses from unittest import TestCase import jwt import responses -from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from pyramid import testing from tests.functional import cleanup_db, create_dummy_request @@ -14,6 +14,8 @@ from tests.functional import setup_db from tests.functional import teardown_common as teardown_module # noqa, pylint: disable=unused-import +from c2cgeoportal_geoportal.lib import oidc + _OIDC_CONFIGURATION = { "issuer": "https://sso.example.com", "authorization_endpoint": "https://sso.example.com/authorize", @@ -41,6 +43,11 @@ } +def includeme(request): + request.get_remember_from_user_info = types.MethodType(oidc.get_remember_from_user_info, request) + request.get_user_from_remember = types.MethodType(oidc.get_user_from_remember, request) + + class TestLogin(TestCase): def setUp(self): setup_db() @@ -66,6 +73,7 @@ def test_login(self): }, params={"came_from": "/came_from"}, ) + includeme(request) responses.get("https://sso.example.com/.well-known/openid-configuration", json=_OIDC_CONFIGURATION) responses.get("https://sso.example.com/jwks", json=_OIDC_KEYS) @@ -106,6 +114,7 @@ def test_callback(self): "authentication": { "openid_connect": { "enabled": True, + "provide_roles": True, "url": "https://sso.example.com", "client_id": "client_id_123", } @@ -118,6 +127,7 @@ def test_callback(self): "code_challenge": "code_challenge", }, ) + includeme(request) responses.get("https://sso.example.com/.well-known/openid-configuration", json=_OIDC_CONFIGURATION) responses.get("https://sso.example.com/jwks", json=_OIDC_KEYS) responses.post(