diff --git a/docs_src/tutorial/lifespan.py b/docs_src/tutorial/lifespan.py index e8bd5dfe..0232b350 100644 --- a/docs_src/tutorial/lifespan.py +++ b/docs_src/tutorial/lifespan.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager from dataclasses import dataclass from typing import AsyncGenerator @@ -13,6 +14,7 @@ class Config: ConfigDep = Annotated[Config, Dependant(scope="app")] +@asynccontextmanager async def lifespan(config: ConfigDep) -> AsyncGenerator[None, None]: print(config.token) yield diff --git a/pyproject.toml b/pyproject.toml index ccc05197..84c2602f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "xpresso" -version = "0.7.3" +version = "0.8.0" description = "A developer centric, performant Python web framework" authors = ["Adrian Garcia Badaracco "] readme = "README.md" diff --git a/tests/test_dependency_injection.py b/tests/test_dependency_injection.py index b83c235a..1c563d99 100644 --- a/tests/test_dependency_injection.py +++ b/tests/test_dependency_injection.py @@ -1,14 +1,11 @@ -import sys - -if sys.version_info < (3, 9): - from typing_extensions import Annotated -else: - from typing import Annotated +from contextlib import asynccontextmanager +from typing import AsyncIterator from starlette.responses import Response from starlette.testclient import TestClient from xpresso import App, Dependant, Operation, Path +from xpresso.typing import Annotated def test_router_route_dependencies() -> None: @@ -44,8 +41,12 @@ def test_lifespan_dependencies() -> None: class Test: ... - async def lifespan(t: Annotated[Test, Dependant(scope="app")]) -> None: + @asynccontextmanager + async def lifespan( + t: Annotated[Test, Dependant(scope="app")] + ) -> AsyncIterator[None]: app.state.t = t # type: ignore[has-type] + yield async def endpoint(t: Annotated[Test, Dependant(scope="app")]) -> Response: assert app.state.t is t # type: ignore[has-type] diff --git a/tests/test_docs/tutorial/routing/test_tutorial002.py b/tests/test_docs/tutorial/routing/test_tutorial002.py index f5973c1f..ab5f0cfd 100644 --- a/tests/test_docs/tutorial/routing/test_tutorial002.py +++ b/tests/test_docs/tutorial/routing/test_tutorial002.py @@ -12,7 +12,7 @@ def test_openapi() -> None: "/v1/items": { "get": { "responses": {"404": {"description": "Item not found"}}, - "tags": ["read", "v1", "items"], + "tags": ["v1", "items", "read"], "summary": "List all items", "description": "The **items** operation", "deprecated": True, @@ -32,7 +32,7 @@ def test_openapi() -> None: }, }, }, - "tags": ["write", "v1", "items"], + "tags": ["v1", "items", "write"], "description": "Documentation from docstrings!\n You can use any valid markdown, for example lists:\n\n - Point 1\n - Point 2\n ", "requestBody": { "content": { diff --git a/tests/test_exception_handlers.py b/tests/test_exception_handlers.py new file mode 100644 index 00000000..1ecc08c6 --- /dev/null +++ b/tests/test_exception_handlers.py @@ -0,0 +1,42 @@ +from xpresso import App, HTTPException, Path, Request +from xpresso.responses import JSONResponse +from xpresso.testclient import TestClient + + +def test_override_base_error_handler() -> None: + async def custom_server_error_from_exception(request: Request, exc: Exception): + return JSONResponse( + {"detail": "Custom Server Error from Exception"}, status_code=500 + ) + + async def raise_exception() -> None: + raise Exception + + async def custom_server_error_from_500(request: Request, exc: Exception): + return JSONResponse( + {"detail": "Custom Server Error from HTTPException(500)"}, status_code=500 + ) + + async def raise_500() -> None: + raise HTTPException(500) + + app = App( + routes=[ + Path("/raise-exception", get=raise_exception), + Path("/raise-500", get=raise_500), + ], + exception_handlers={ + Exception: custom_server_error_from_exception, + 500: custom_server_error_from_500, + }, + ) + + client = TestClient(app) + + resp = client.get("/raise-exception") + assert resp.status_code == 500, resp.content + assert resp.json() == {"detail": "Custom Server Error from Exception"} + + resp = client.get("/raise-500") + assert resp.status_code == 500, resp.content + assert resp.json() == {"detail": "Custom Server Error from HTTPException(500)"} diff --git a/tests/test_lifespans.py b/tests/test_lifespans.py new file mode 100644 index 00000000..4c59967a --- /dev/null +++ b/tests/test_lifespans.py @@ -0,0 +1,38 @@ +from contextlib import asynccontextmanager +from typing import AsyncIterator, List + +from xpresso import App, Dependant, Router +from xpresso.routing.mount import Mount +from xpresso.testclient import TestClient + + +def test_lifespan_mounted_app() -> None: + class Counter(List[int]): + pass + + @asynccontextmanager + async def lifespan(counter: Counter) -> AsyncIterator[None]: + counter.append(1) + yield + + counter = Counter() + + inner_app = App(lifespan=lifespan) + inner_app.container.register_by_type( + Dependant(lambda: counter, scope="app"), Counter + ) + + app = App( + routes=[ + Mount("/mounted-app", app=inner_app), + Mount("/mounted-router", app=Router([], lifespan=lifespan)), + ], + lifespan=lifespan, + ) + + app.container.register_by_type(Dependant(lambda: counter, scope="app"), Counter) + + with TestClient(app): + pass + + assert counter == [1, 1, 1] diff --git a/tests/test_routing/test_mounts.py b/tests/test_routing/test_mounts.py index 57e38ccd..a4437726 100644 --- a/tests/test_routing/test_mounts.py +++ b/tests/test_routing/test_mounts.py @@ -1,7 +1,9 @@ """Tests for experimental OpenAPI inspired routing""" from typing import Any, Dict -from xpresso import App, FromPath, Path +from di import BaseContainer + +from xpresso import App, Dependant, FromPath, Path from xpresso.routing.mount import Mount from xpresso.testclient import TestClient @@ -126,7 +128,7 @@ def test_openapi_routing_for_mounted_path() -> None: assert resp.json() == expected_openapi -def test_xpresso_app_as_app_param_to_mount_routing() -> None: +def test_mounted_xpresso_app_routing() -> None: # not a use case we advertise # but we want to know what the behavior is app = App( @@ -152,7 +154,7 @@ def test_xpresso_app_as_app_param_to_mount_routing() -> None: assert resp.json() == 124 -def test_xpresso_app_as_app_param_to_mount_openapi() -> None: +def test_mounted_xpresso_app_openapi() -> None: # not a use case we advertise # but we want to know what the behavior is app = App( @@ -176,9 +178,170 @@ def test_xpresso_app_as_app_param_to_mount_openapi() -> None: expected_openapi: Dict[str, Any] = { "openapi": "3.0.3", "info": {"title": "API", "version": "0.1.0"}, - "paths": {}, + "paths": { + "/mount/{number}": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {"title": "Response", "type": "integer"} + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "parameters": [ + { + "required": True, + "style": "simple", + "explode": False, + "schema": {"title": "Number", "type": "integer"}, + "name": "number", + "in": "path", + } + ], + } + } + }, + "components": { + "schemas": { + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": { + "oneOf": [{"type": "string"}, {"type": "integer"}] + }, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + } + }, } resp = client.get("/openapi.json") assert resp.status_code == 200, resp.content assert resp.json() == expected_openapi + + +def test_mounted_xpresso_app_dependencies_isolated_containers() -> None: + # not a use case we advertise + # but we want to know what the behavior is + + class Thing: + def __init__(self, value: str = "default") -> None: + self.value = value + + async def endpoint(thing: Thing) -> str: + return thing.value + + inner_app = App( + routes=[ + Path( + path="/", + get=endpoint, + ) + ], + ) + + app = App( + routes=[ + Mount( + path="/mount", + app=inner_app, + ), + Path("/top-level", get=endpoint), + ] + ) + + app.container.register_by_type( + Dependant(lambda: Thing("injected")), + Thing, + ) + + client = TestClient(app) + + resp = client.get("/top-level") + assert resp.status_code == 200, resp.content + assert resp.json() == "injected" + + resp = client.get("/mount") + assert resp.status_code == 200, resp.content + assert resp.json() == "default" + + +def test_mounted_xpresso_app_dependencies_shared_containers() -> None: + # not a use case we advertise + # but we want to know what the behavior is + + class Thing: + def __init__(self, value: str = "default") -> None: + self.value = value + + async def endpoint(thing: Thing) -> str: + return thing.value + + container = BaseContainer(scopes=("app", "connection", "operation")) + container.register_by_type( + Dependant(lambda: Thing("injected")), + Thing, + ) + + inner_app = App( + routes=[ + Path( + path="/", + get=endpoint, + ) + ], + container=container, + ) + + app = App( + routes=[ + Mount( + path="/mount", + app=inner_app, + ), + Path("/top-level", get=endpoint), + ], + container=container, + ) + + client = TestClient(app) + + resp = client.get("/top-level") + assert resp.status_code == 200, resp.content + assert resp.json() == "injected" + + resp = client.get("/mount") + assert resp.status_code == 200, resp.content + assert resp.json() == "injected" diff --git a/tests/test_routing/test_router.py b/tests/test_routing/test_router.py new file mode 100644 index 00000000..206414f1 --- /dev/null +++ b/tests/test_routing/test_router.py @@ -0,0 +1,121 @@ +from typing import Any, Dict + +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint + +from xpresso import App, Path, Request, Response, Router +from xpresso.routing.mount import Mount +from xpresso.testclient import TestClient + + +def test_router_middleware() -> None: + async def endpoint() -> None: + ... + + class AddCustomHeaderMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + resp = await call_next(request) + resp.headers["X-Custom"] = "123" + return resp + + app = App( + routes=[ + Mount( + "/with-middleware", + app=Router( + routes=[ + Path( + "/", + get=endpoint, + ) + ], + middleware=[Middleware(AddCustomHeaderMiddleware)], + ), + ), + Mount( + "/without-middleware", + app=Router( + routes=[ + Path( + "/", + get=endpoint, + ) + ] + ), + ), + ] + ) + + client = TestClient(app) + + resp = client.get("/with-middleware/") + assert resp.status_code == 200, resp.content + assert resp.headers["X-Custom"] == "123" + + resp = client.get("/without-middleware/") + assert resp.status_code == 200, resp.content + assert "X-Custom" not in resp.headers + + +def test_router_middleware_modify_path() -> None: + async def endpoint() -> None: + ... + + class RerouteMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + request.scope["path"] = request.scope["path"].replace("bad", "good") + return await call_next(request) + + app = App( + routes=[ + Mount( + "/", + app=Router( + routes=[ + Path( + "/good", + get=endpoint, + ) + ], + middleware=[Middleware(RerouteMiddleware)], + ), + ), + ] + ) + + client = TestClient(app) + + resp = client.get("/bad") + assert resp.status_code == 200, resp.content + + resp = client.get("/very-bad") + assert resp.status_code == 404, resp.content + + +def test_exclude_from_schema() -> None: + app = App( + routes=[ + Mount( + "/mount", + app=Router( + routes=[Path("/test", get=lambda: None)], include_in_schema=False + ), + ) + ] + ) + + expected_openapi_json: Dict[str, Any] = { + "openapi": "3.0.3", + "info": {"title": "API", "version": "0.1.0"}, + "paths": {}, + } + + client = TestClient(app) + + resp = client.get("/openapi.json") + assert resp.status_code == 200, resp.content + assert resp.json() == expected_openapi_json diff --git a/xpresso/_utils/deprecation.py b/xpresso/_utils/deprecation.py new file mode 100644 index 00000000..eb874141 --- /dev/null +++ b/xpresso/_utils/deprecation.py @@ -0,0 +1,12 @@ +import typing + + +def not_supported(method: str) -> typing.Callable[..., typing.Any]: + """Marks a method as not supported + Used to hard-deprecate things from Starlette + """ + + def raise_error(*args: typing.Any, **kwargs: typing.Any) -> typing.NoReturn: + raise NotImplementedError(f"Use of {method} is not supported") + + return raise_error diff --git a/xpresso/_utils/routing.py b/xpresso/_utils/routing.py index 2af5f159..f9e6e282 100644 --- a/xpresso/_utils/routing.py +++ b/xpresso/_utils/routing.py @@ -1,58 +1,72 @@ +import sys import typing from dataclasses import dataclass +if sys.version_info < (3, 8): + from typing_extensions import Protocol +else: + from typing import Protocol + from starlette.routing import BaseRoute, Mount from starlette.routing import Router as StarletteRouter -from xpresso.responses import Responses from xpresso.routing.pathitem import Path -from xpresso.routing.router import Router as XpressoRouter +from xpresso.routing.websockets import WebSocketRoute + + +class App(Protocol): + @property + def router(self) -> StarletteRouter: + ... + + +AppType = typing.TypeVar("AppType", bound=App) @dataclass(frozen=True) -class VisitedRoute: +class VisitedRoute(typing.Generic[AppType]): path: str - routers: typing.List[StarletteRouter] + nodes: typing.List[typing.Union[StarletteRouter, AppType]] route: BaseRoute - tags: typing.List[str] - responses: Responses def visit_routes( - routers: typing.List[StarletteRouter], - path: typing.Optional[str] = None, - tags: typing.Optional[typing.List[str]] = None, - responses: typing.Optional[Responses] = None, -) -> typing.Generator[VisitedRoute, None, None]: - path = path or "" - tags = tags or [] - responses = responses or {} - router = next(iter(reversed(routers)), None) - assert router is not None - for route in typing.cast(typing.Iterable[BaseRoute], router.routes): - if isinstance(route, Mount) and isinstance(route.app, StarletteRouter): - child_tags = tags - child_responses = responses - if isinstance(route.app, XpressoRouter): - child_tags = child_tags + route.app.tags - child_responses = {**child_responses, **route.app.responses} - yield from visit_routes( - routers=routers + [route.app], - path=path + route.path, - tags=child_tags, - responses=child_responses, - ) - elif hasattr(route, "path"): - route_path: str = route.path # type: ignore - child_tags = tags - child_responses = responses - if isinstance(route, Path): - child_tags = child_tags + route.tags - child_responses = {**child_responses, **route.responses} + app_type: typing.Type[AppType], + router: StarletteRouter, + nodes: typing.List[typing.Union[StarletteRouter, AppType]], + path: str, +) -> typing.Generator[VisitedRoute[AppType], None, None]: + for route in typing.cast(typing.Iterable[BaseRoute], router.routes): # type: ignore # for Pylance + if isinstance(route, Mount): + app: typing.Any = route.app + mount_path: str = route.path # type: ignore # for Pylance + if isinstance(app, StarletteRouter): + yield VisitedRoute( + path=path, + nodes=nodes + [app], + route=route, + ) + yield from visit_routes( + app_type=app_type, + router=app, + nodes=nodes + [app], + path=path + mount_path, + ) + elif isinstance(app, app_type): + yield VisitedRoute( + path=path, + nodes=nodes + [app, app.router], + route=route, + ) + yield from visit_routes( + app_type=app_type, + router=app.router, + nodes=nodes + [app, app.router], + path=path + mount_path, + ) + elif isinstance(route, (Path, WebSocketRoute)): yield VisitedRoute( - path=path + route_path, - routers=routers, + path=path + route.path, + nodes=nodes, route=route, - tags=child_tags, - responses=child_responses, ) diff --git a/xpresso/applications.py b/xpresso/applications.py index 3d985178..d92becdb 100644 --- a/xpresso/applications.py +++ b/xpresso/applications.py @@ -1,11 +1,10 @@ +import inspect import typing from contextlib import asynccontextmanager import starlette.types -from di import AsyncExecutor, BaseContainer +from di import AsyncExecutor, BaseContainer, JoinedDependant from di.api.dependencies import DependantBase -from di.api.providers import DependencyProviderType -from starlette.applications import Starlette from starlette.datastructures import State from starlette.middleware import Middleware from starlette.middleware.errors import ServerErrorMiddleware @@ -32,19 +31,61 @@ from xpresso.routing.websockets import WebSocketRoute from xpresso.security._dependants import Security -ExceptionHandler = typing.Callable[[Request, typing.Type[BaseException]], Response] +ExceptionHandler = typing.Callable[ + [Request, Exception], typing.Union[Response, typing.Awaitable[Response]] +] +ExceptionHandlers = typing.Mapping[ + typing.Union[typing.Type[Exception], int], ExceptionHandler +] -class App(Starlette): +def _include_error_middleware( + debug: bool, + user_middleware: typing.Iterable[Middleware], + exception_handlers: ExceptionHandlers, +) -> typing.Sequence[Middleware]: + # user's exception handlers come last so that they can override + # the default exception handlers + exception_handlers = { + RequestValidationError: validation_exception_handler, + HTTPException: http_exception_handler, + **exception_handlers, + } + + error_handler = None + for key, value in exception_handlers.items(): + if key in (500, Exception): + error_handler = value + else: + exception_handlers[key] = value + + return ( + Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug), + *user_middleware, + Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug), + ) + + +def _wrap_lifespan_as_async_generator( + lifespan: typing.Callable[..., typing.AsyncContextManager[None]] +) -> typing.Callable[..., typing.AsyncIterator[None]]: + async def gen( + *args: typing.Any, **kwargs: typing.Any + ) -> typing.AsyncIterator[None]: + async with lifespan(*args, **kwargs): + yield + + sig = inspect.signature(gen) + sig = sig.replace(parameters=list(inspect.signature(lifespan).parameters.values())) + setattr(gen, "__signature__", sig) + + return gen + + +class App: router: Router - middleware_stack: starlette.types.ASGIApp - openapi: typing.Optional[openapi_models.OpenAPI] = None - _debug: bool + openapi: typing.Optional[openapi_models.OpenAPI] state: State - exception_handlers: typing.Mapping[ - typing.Union[int, typing.Type[Exception]], ExceptionHandler - ] - user_middleware: typing.Sequence[Middleware] container: BaseContainer def __init__( @@ -55,13 +96,11 @@ def __init__( dependencies: typing.Optional[typing.List[Dependant]] = None, debug: bool = False, middleware: typing.Optional[typing.Sequence[Middleware]] = None, - exception_handlers: typing.Optional[ - typing.Dict[ - typing.Union[int, typing.Type[Exception]], - ExceptionHandler, - ] + exception_handlers: typing.Optional[ExceptionHandlers] = None, + lifespan: typing.Optional[ + typing.Callable[..., typing.AsyncContextManager[None]] ] = None, - lifespan: typing.Optional[DependencyProviderType[None]] = None, + include_in_schema: bool = True, openapi_version: str = "3.0.3", title: str = "API", description: typing.Optional[str] = None, @@ -70,19 +109,6 @@ def __init__( docs_url: typing.Optional[str] = "/docs", servers: typing.Optional[typing.Iterable[openapi_models.Server]] = None, ) -> None: - routes = list(routes or []) - routes.extend( - self._get_doc_routes( - openapi_url=openapi_url, - docs_url=docs_url, - ) - ) - self._debug = debug - self.state = State() - self.exception_handlers = ( - {} if exception_handlers is None else dict(exception_handlers) - ) - self.container = container or BaseContainer( scopes=("app", "connection", "operation") ) @@ -90,29 +116,57 @@ def __init__( self._setup_run = False @asynccontextmanager - async def lifespan_ctx(app: Starlette) -> typing.AsyncGenerator[None, None]: - self._setup() + async def lifespan_ctx(*args: typing.Any) -> typing.AsyncIterator[None]: + lifespans = self._setup() self._setup_run = True original_container = self.container async with self.container.enter_scope("app") as container: self.container = container if lifespan is not None: - await container.execute_async( - self.container.solve(Dependant(call=lifespan, scope="app")), - executor=AsyncExecutor(), + dep = Dependant( + _wrap_lifespan_as_async_generator(lifespan), scope="app" + ) + else: + dep = Dependant(lambda: None, scope="app") + solved = self.container.solve( + JoinedDependant( + dep, + siblings=[ + Dependant(lifespan, scope="app") for lifespan in lifespans + ], ) + ) try: + await container.execute_async(solved, executor=AsyncExecutor()) yield finally: # make this cm reentrant for testing purposes self.container = original_container self._setup_run = False - self.router = Router(routes, lifespan=lifespan_ctx, dependencies=dependencies) - self.user_middleware = [] if middleware is None else list(middleware) - self.middleware_stack = self.build_middleware_stack() # type: ignore - self.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore - self.add_exception_handler(HTTPException, http_exception_handler) # type: ignore + self._debug = debug + self.state = State() + + routes = list(routes or []) + routes.extend( + self._get_doc_routes( + openapi_url=openapi_url, + docs_url=docs_url, + ) + ) + middleware = _include_error_middleware( + debug=debug, + user_middleware=middleware or (), + exception_handlers=exception_handlers or {}, + ) + self.router = Router( + routes, + dependencies=dependencies, + middleware=middleware, + include_in_schema=include_in_schema, + lifespan=lifespan_ctx, + ) + self.openapi_version = openapi_version self.openapi_info = openapi_models.Info( title=title, @@ -120,28 +174,7 @@ async def lifespan_ctx(app: Starlette) -> typing.AsyncGenerator[None, None]: description=description, ) self.servers = servers - - def build_middleware_stack(self) -> starlette.types.ASGIApp: - debug = self.debug - error_handler = None - exception_handlers = {} - - for key, value in self.exception_handlers.items(): - if key in (500, Exception): - error_handler = value - else: - exception_handlers[key] = value - - middleware = ( - Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug), - *self.user_middleware, - Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug), - ) - - app = self.router - for cls, options in reversed(middleware): - app = cls(app=app, **options) - return app + self.openapi = None async def __call__( self, @@ -149,31 +182,42 @@ async def __call__( receive: starlette.types.Receive, send: starlette.types.Send, ) -> None: - self._setup() - if scope["type"] == "http" or scope["type"] == "websocket": + scope["app"] = self + scope_type = scope["type"] + if scope_type == "http" or scope_type == "websocket": + if not self._setup_run: + self._setup() extensions = scope.get("extensions", None) or {} scope["extensions"] = extensions - xpresso_scope = extensions.get("xpresso", None) - if xpresso_scope is None: - async with self.container.enter_scope("connection") as container: - xpresso_asgi_extension: XpressoASGIExtension = { - "container": container, - "response_sent": False, - } - extensions["xpresso"] = xpresso_asgi_extension - await super().__call__(scope, receive, send) - xpresso_asgi_extension["response_sent"] = True - return - await super().__call__(scope, receive, send) - - def _setup(self) -> None: - if self._setup_run: + xpresso_asgi_extension: XpressoASGIExtension = extensions.get("xpresso", None) or {} # type: ignore[assignment] + extensions["xpresso"] = xpresso_asgi_extension + async with self.container.enter_scope("connection") as container: + xpresso_asgi_extension["response_sent"] = False + xpresso_asgi_extension["container"] = container + await self.router(scope, receive, send) + xpresso_asgi_extension["response_sent"] = True return - for route in visit_routes([self.router]): + else: # lifespan + await self.router(scope, receive, send) + + def _setup(self) -> typing.List[typing.Callable[..., typing.AsyncIterator[None]]]: + lifespans: typing.List[typing.Callable[..., typing.AsyncIterator[None]]] = [] + for route in visit_routes( + app_type=App, router=self.router, nodes=[self, self.router], path="" + ): dependencies: typing.List[DependantBase[typing.Any]] = [] - for router in route.routers: - if isinstance(router, Router): - dependencies.extend(router.dependencies) + for node in route.nodes: + if isinstance(node, Router): + dependencies.extend(node.dependencies) + if node is not self.router: # avoid circul lifespan calls + lifespan = typing.cast( + typing.Callable[..., typing.AsyncContextManager[None]], + node.lifespan_context, # type: ignore # for Pylance + ) + if lifespan is not None: + lifespans.append( + _wrap_lifespan_as_async_generator(lifespan) + ) if isinstance(route.route, Path): for operation in route.route.operations.values(): operation.solve( @@ -184,7 +228,7 @@ def _setup(self) -> None: ], container=self.container, ) - if isinstance(route.route, WebSocketRoute): + elif isinstance(route.route, WebSocketRoute): route.route.solve( dependencies=[ *dependencies, @@ -192,19 +236,20 @@ def _setup(self) -> None: ], container=self.container, ) + return lifespans async def get_openapi(self) -> openapi_models.OpenAPI: return genrate_openapi( + visitor=visit_routes(app_type=App, router=self.router, nodes=[self, self.router], path=""), # type: ignore # for Pylance version=self.openapi_version, info=self.openapi_info, servers=self.servers, - router=self.router, security_models=await self.gather_security_models(), ) async def gather_security_models(self) -> SecurityModels: security_dependants: typing.List[Security] = [] - for route in visit_routes([self.router]): + for route in visit_routes(app_type=App, router=self.router, nodes=[self, self.router], path=""): # type: ignore[misc] if isinstance(route.route, Path): for operation in route.route.operations.values(): dependant = operation.dependant @@ -247,7 +292,7 @@ async def openapi(req: Request) -> JSONResponse: openapi_url = openapi_url async def swagger_ui_html(req: Request) -> HTMLResponse: - root_path: str = req.scope.get("root_path", "").rstrip("/") + root_path: str = req.scope.get("root_path", "").rstrip("/") # type: ignore # for Pylance full_openapi_url = root_path + openapi_url # type: ignore[operator] return get_swagger_ui_html( openapi_url=full_openapi_url, diff --git a/xpresso/exception_handlers.py b/xpresso/exception_handlers.py index e9e45423..ed164935 100644 --- a/xpresso/exception_handlers.py +++ b/xpresso/exception_handlers.py @@ -8,7 +8,8 @@ encoder = JsonableEncoder() -async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: +async def http_exception_handler(request: Request, exc: Exception) -> JSONResponse: + assert isinstance(exc, HTTPException) headers = getattr(exc, "headers", None) if headers: return JSONResponse( @@ -19,8 +20,9 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe async def validation_exception_handler( - request: Request, exc: RequestValidationError + request: Request, exc: Exception ) -> JSONResponse: + assert isinstance(exc, RequestValidationError) return JSONResponse( encoder({"detail": exc.errors()}), status_code=exc.status_code, diff --git a/xpresso/openapi/_builder.py b/xpresso/openapi/_builder.py index 593792f8..4e3ba511 100644 --- a/xpresso/openapi/_builder.py +++ b/xpresso/openapi/_builder.py @@ -6,6 +6,7 @@ Any, Callable, Dict, + Generator, Iterable, List, Mapping, @@ -14,6 +15,7 @@ Sequence, Type, Union, + cast, ) if sys.version_info < (3, 9): @@ -28,7 +30,7 @@ from starlette.responses import Response from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY -from xpresso._utils.routing import visit_routes +from xpresso._utils.routing import VisitedRoute from xpresso.binders import dependants as binder_dependants from xpresso.openapi import models from xpresso.openapi.constants import REF_PREFIX @@ -293,11 +295,11 @@ def get_operation( model_name_map: ModelNameMap, components: Dict[str, Any], security_models: Mapping[Security, SecurityBase], - parent_tags: List[str], - parent_responses: Responses, + tags: List[str], + response_specs: Responses, ) -> models.Operation: data: Dict[str, Any] = { - "tags": [*route.tags, *parent_tags] or None, + "tags": tags or None, "summary": route.summary, "description": route.description, "deprecated": route.deprecated, @@ -344,7 +346,6 @@ def get_operation( ], security_schemes, ) - response_specs = {**parent_responses, **route.responses} data["responses"] = get_responses( route, response_specs=response_specs, @@ -379,18 +380,33 @@ def get_operation( def get_paths_items( - router: Router, + visitor: Generator[VisitedRoute[Any], None, None], model_name_map: ModelNameMap, components: Dict[str, Any], security_models: Mapping[Security, SecurityBase], ) -> Dict[str, models.PathItem]: paths: Dict[str, models.PathItem] = {} - for visited_route in visit_routes([router]): + for visited_route in visitor: if isinstance(visited_route.route, Path): - if not visited_route.route.include_in_schema: + path_item = visited_route.route + if not path_item.include_in_schema: + continue + tags: List[str] = [] + responses = dict(cast(Responses, {})) + include_in_schema = True + for node in visited_route.nodes: + if isinstance(node, Router): + if not node.include_in_schema: + include_in_schema = False + break + responses.update(node.responses) + tags.extend(node.tags) + if not include_in_schema: continue + tags.extend(path_item.tags) + responses.update(path_item.responses) operations: Dict[str, models.Operation] = {} - for method, operation in visited_route.route.operations.items(): + for method, operation in path_item.operations.items(): if not operation.include_in_schema: continue operations[method.lower()] = get_operation( @@ -398,8 +414,8 @@ def get_paths_items( model_name_map=model_name_map, components=components, security_models=security_models, - parent_tags=visited_route.tags, - parent_responses=visited_route.responses, + tags=tags + operation.tags, + response_specs={**responses, **operation.responses}, ) paths[visited_route.path] = models.PathItem( description=visited_route.route.description, @@ -411,15 +427,15 @@ def get_paths_items( def genrate_openapi( + visitor: Generator[VisitedRoute[Any], None, None], version: str, info: models.Info, servers: Optional[Iterable[models.Server]], - router: Router, security_models: Mapping[Security, SecurityBase], ) -> models.OpenAPI: model_name_map: ModelNameMap = {} components: Dict[str, Any] = {} - paths = get_paths_items(router, model_name_map, components, security_models) + paths = get_paths_items(visitor, model_name_map, components, security_models) return models.OpenAPI( openapi=version, info=info, diff --git a/xpresso/routing/router.py b/xpresso/routing/router.py index 8ce07c3c..7d71b19a 100644 --- a/xpresso/routing/router.py +++ b/xpresso/routing/router.py @@ -1,39 +1,35 @@ import typing -from starlette.applications import Starlette +import starlette.middleware from starlette.routing import BaseRoute from starlette.routing import Router as StarletteRouter -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Receive, Scope, Send +from xpresso._utils.deprecation import not_supported from xpresso.dependencies.models import Dependant from xpresso.responses import Responses -def _not_supported(method: str) -> typing.Callable[..., typing.Any]: - def raise_error(*args: typing.Any, **kwargs: typing.Any) -> typing.NoReturn: - raise NotImplementedError( - f"Use of Router.{method} is deprecated." - " Use Router(routes=[...]) instead." - ) - - return raise_error - - class Router(StarletteRouter): routes: typing.List[BaseRoute] + _app: ASGIApp def __init__( self, routes: typing.Sequence[BaseRoute], *, - redirect_slashes: bool = True, - default: typing.Optional[ASGIApp] = None, + middleware: typing.Optional[ + typing.Sequence[starlette.middleware.Middleware] + ] = None, lifespan: typing.Optional[ - typing.Callable[[Starlette], typing.AsyncContextManager[None]] + typing.Callable[..., typing.AsyncContextManager[None]] ] = None, + redirect_slashes: bool = True, + default: typing.Optional[ASGIApp] = None, dependencies: typing.Optional[typing.List[Dependant]] = None, tags: typing.Optional[typing.List[str]] = None, responses: typing.Optional[Responses] = None, + include_in_schema: bool = True, ) -> None: super().__init__( # type: ignore routes=list(routes), @@ -44,12 +40,29 @@ def __init__( self.dependencies = list(dependencies or []) self.tags = list(tags or []) self.responses = dict(responses or {}) + self.include_in_schema = include_in_schema + self._app = super().__call__ # type: ignore[assignment,misc] + if middleware is not None: + for cls, options in reversed(middleware): # type: ignore # for Pylance + self._app = cls(app=self._app, **options) # type: ignore[assignment,misc] + + async def __call__( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + + if "router" not in scope: + scope["router"] = self + + await self._app(scope, receive, send) # type: ignore[arg-type,call-arg,misc] - mount = _not_supported("mount") - host = _not_supported("host") - add_route = _not_supported("add_route") - add_websocket_route = _not_supported("add_websocket_route") - route = _not_supported("route") - websocket_route = _not_supported("websocket_route") - add_event_handler = _not_supported("add_event_handler") - on_event = _not_supported("on_event") + mount = not_supported("mount") + host = not_supported("host") + add_route = not_supported("add_route") + add_websocket_route = not_supported("add_websocket_route") + route = not_supported("route") + websocket_route = not_supported("websocket_route") + add_event_handler = not_supported("add_event_handler") + on_event = not_supported("on_event") diff --git a/xpresso/routing/websockets.py b/xpresso/routing/websockets.py index 72c72806..294193d5 100644 --- a/xpresso/routing/websockets.py +++ b/xpresso/routing/websockets.py @@ -54,6 +54,8 @@ async def __call__( class WebSocketRoute(starlette.routing.WebSocketRoute): + path: str + def __init__( self, path: str,