Skip to content

Commit

Permalink
wip: support routes that match all methods
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Jun 4, 2024
1 parent 4a00c1f commit 2960894
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 72 deletions.
13 changes: 8 additions & 5 deletions core/codegen/src/attribute/route/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ fn internal_uri_macro_decl(route: &Route) -> TokenStream {
// Generate a unique macro name based on the route's metadata.
let macro_name = route.handler.sig.ident.prepend(crate::URI_MACRO_PREFIX);
let inner_macro_name = macro_name.uniqueify_with(|mut hasher| {
route.attr.method.0.hash(&mut hasher);
route.attr.method.as_ref().map(|m| m.0.hash(&mut hasher));
route.attr.uri.path().hash(&mut hasher);
route.attr.uri.query().hash(&mut hasher);
route.attr.data.as_ref().map(|d| d.value.hash(&mut hasher));
Expand Down Expand Up @@ -363,7 +363,7 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
let internal_uri_macro = internal_uri_macro_decl(&route);
let responder_outcome = responder_outcome_expr(&route);

let method = &route.attr.method;
let method = Optional(route.attr.method.clone());
let uri = route.attr.uri.to_string();
let rank = Optional(route.attr.rank);
let format = Optional(route.attr.format.as_ref());
Expand Down Expand Up @@ -448,9 +448,12 @@ fn incomplete_route(
let method_attribute = MethodAttribute::from_meta(&syn::parse2(full_attr)?)?;

let attribute = Attribute {
method: SpanWrapped {
full_span: method_span, key_span: None, span: method_span, value: Method(method)
},
method: Some(SpanWrapped {
full_span: method_span,
key_span: None,
span: method_span,
value: Method(method),
}),
uri: method_attribute.uri,
data: method_attribute.data,
format: method_attribute.format,
Expand Down
25 changes: 15 additions & 10 deletions core/codegen/src/attribute/route/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,8 @@ pub struct Arguments {
/// The parsed `#[route(..)]` attribute.
#[derive(Debug, FromMeta)]
pub struct Attribute {
#[meta(naked)]
pub method: SpanWrapped<Method>,
pub uri: RouteUri,
pub method: Option<SpanWrapped<Method>>,
pub data: Option<SpanWrapped<Dynamic>>,
pub format: Option<MediaType>,
pub rank: Option<isize>,
Expand Down Expand Up @@ -129,17 +128,23 @@ impl Route {
// Emit a warning if a `data` param was supplied for non-payload methods.
if let Some(ref data) = attr.data {
let lint = Lint::DubiousPayload;
match attr.method.0.allows_request_body() {
None if lint.enabled(handler.span()) => {
match attr.method.as_ref() {
Some(m) if m.0.allows_request_body() == Some(false) => {
diags.push(data.full_span
.error("`data` cannot be used on this route")
.span_note(m.span, "method does not support request payloads"))
},
Some(m) if m.0.allows_request_body().is_none() && lint.enabled(handler.span()) => {
data.full_span.warning("`data` used with non-payload-supporting method")
.note(format!("'{}' does not typically support payloads", attr.method.0))
.span_note(m.span, format!("'{}' does not typically support payloads", m.0))
.note(lint.how_to_suppress())
.emit_as_item_tokens();
},
None if lint.enabled(handler.span()) => {
data.full_span.warning("`data` used on route with wildcard method")
.note("some methods may not support request payloads")
.note(lint.how_to_suppress())
.emit_as_item_tokens();
}
Some(false) => {
diags.push(data.full_span
.error("`data` cannot be used on this route")
.span_note(attr.method.span, "method does not support request payloads"))
}
_ => { /* okay */ },
}
Expand Down
2 changes: 1 addition & 1 deletion core/codegen/tests/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ fn post1(
}

#[route(
POST,
method = POST,
uri = "/<a>/<name>/name/<path..>?sky=blue&<sky>&<query..>",
format = "json",
data = "<simple>",
Expand Down
4 changes: 4 additions & 0 deletions core/http/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ macro_rules! define_methods {
#[doc(hidden)]
pub const ALL: &'static [&'static str] = &[$($name),*];

/// A slice containing every defined method variant.
#[doc(hidden)]
pub const ALL_VARIANTS: &'static [Method] = &[$(Self::$V),*];

/// Whether the method is considered "safe".
///
/// From [RFC9110 §9.2.1](https://www.rfc-editor.org/rfc/rfc9110#section-9.2.1):
Expand Down
1 change: 1 addition & 0 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ tokio-stream = { version = "0.1.6", features = ["signal", "time"] }
cookie = { version = "0.18", features = ["percent-encode"] }
futures = { version = "0.3.30", default-features = false, features = ["std"] }
state = "0.6"
rustc-hash = "1.1"

# tracing
tracing = { version = "0.1.40", default-features = false, features = ["std", "attributes"] }
Expand Down
18 changes: 11 additions & 7 deletions core/lib/src/route/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ pub struct Route {
/// The name of this route, if one was given.
pub name: Option<Cow<'static, str>>,
/// The method this route matches against.
pub method: Method,
pub method: Option<Method>,
/// The function that should be called when the route matches.
pub handler: Box<dyn Handler>,
/// The route URI.
Expand Down Expand Up @@ -207,8 +207,8 @@ impl Route {
/// assert_eq!(index.uri, "/");
/// ```
#[track_caller]
pub fn new<H: Handler>(method: Method, uri: &str, handler: H) -> Route {
Route::ranked(None, method, uri, handler)
pub fn new<M: Into<Option<Method>>, H: Handler>(method: M, uri: &str, handler: H) -> Route {
Route::ranked(None, method.into(), uri, handler)
}

/// Creates a new route with the given rank, method, path, and handler with
Expand Down Expand Up @@ -242,8 +242,10 @@ impl Route {
/// assert_eq!(foo.uri, "/foo?bar");
/// ```
#[track_caller]
pub fn ranked<H, R>(rank: R, method: Method, uri: &str, handler: H) -> Route
where H: Handler + 'static, R: Into<Option<isize>>,
pub fn ranked<M, H, R>(rank: R, method: M, uri: &str, handler: H) -> Route
where M: Into<Option<Method>>,
H: Handler + 'static,
R: Into<Option<isize>>,
{
let uri = RouteUri::new("/", uri);
let rank = rank.into().unwrap_or_else(|| uri.default_rank());
Expand All @@ -253,7 +255,9 @@ impl Route {
sentinels: Vec::new(),
handler: Box::new(handler),
location: None,
rank, uri, method,
method: method.into(),
rank,
uri,
}
}

Expand Down Expand Up @@ -362,7 +366,7 @@ pub struct StaticInfo {
/// The route's name, i.e, the name of the function.
pub name: &'static str,
/// The route's method.
pub method: Method,
pub method: Option<Method>,
/// The route's URi, without the base mount point.
pub uri: &'static str,
/// The route's format, if any.
Expand Down
14 changes: 11 additions & 3 deletions core/lib/src/router/collider.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::catcher::Catcher;
use crate::route::{Route, Segment, RouteUri};

use crate::http::MediaType;
use crate::http::{MediaType, Method};

pub trait Collide<T = Self> {
fn collides_with(&self, other: &T) -> bool;
Expand Down Expand Up @@ -87,7 +87,7 @@ impl Route {
/// assert!(a.collides_with(&b));
/// ```
pub fn collides_with(&self, other: &Route) -> bool {
self.method == other.method
methods_collide(self, other)
&& self.rank == other.rank
&& self.uri.collides_with(&other.uri)
&& formats_collide(self, other)
Expand Down Expand Up @@ -190,8 +190,16 @@ impl Collide for MediaType {
}
}

fn methods_collide(route: &Route, other: &Route) -> bool {
match (route.method, other.method) {
(Some(a), Some(b)) => a == b,
(None, _) | (_, None) => true,
}
}

fn formats_collide(route: &Route, other: &Route) -> bool {
match (route.method.allows_request_body(), other.method.allows_request_body()) {
let payload_support = |m: &Option<Method>| m.and_then(|m| m.allows_request_body());
match (payload_support(&route.method), payload_support(&other.method)) {
// Payload supporting methods match against `Content-Type` which must be
// fully specified, so the request cannot contain a format that matches
// more than one route format as long as those formats don't collide.
Expand Down
10 changes: 7 additions & 3 deletions core/lib/src/router/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ impl Route {
/// ```
#[tracing::instrument(level = "trace", name = "matching", skip_all, ret)]
pub fn matches(&self, request: &Request<'_>) -> bool {
trace!(route.method = %self.method, request.method = %request.method());
self.method == request.method()
methods_match(self, request)
&& paths_match(self, request)
&& queries_match(self, request)
&& formats_match(self, request)
Expand Down Expand Up @@ -140,6 +139,11 @@ impl Catcher {
}
}

fn methods_match(route: &Route, req: &Request<'_>) -> bool {
trace!(?route.method, request.method = %req.method());
route.method.map_or(true, |method| method == req.method())
}

fn paths_match(route: &Route, req: &Request<'_>) -> bool {
trace!(route.uri = %route.uri, request.uri = %req.uri());
let route_segments = &route.uri.metadata.uri_segments;
Expand Down Expand Up @@ -208,7 +212,7 @@ fn formats_match(route: &Route, req: &Request<'_>) -> bool {
None => return true,
};

match route.method.allows_request_body() {
match route.method.and_then(|m| m.allows_request_body()) {
Some(true) => match req.format() {
Some(f) if f.specificity() == 2 => route_format.collides_with(f),
_ => false
Expand Down
46 changes: 34 additions & 12 deletions core/lib/src/router/router.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use std::collections::HashMap;
use rustc_hash::FxHashMap;

use crate::request::Request;
use crate::http::{Method, Status};

use crate::{Route, Catcher};
use crate::router::Collide;

#[derive(Debug, Default)]
pub(crate) struct Router {
routes: HashMap<Method, Vec<Route>>,
catchers: HashMap<Option<u16>, Vec<Catcher>>,
routes: FxHashMap<Option<Method>, Vec<Route>>,
final_routes: FxHashMap<Method, Vec<Route>>,
catchers: FxHashMap<Option<u16>, Vec<Catcher>>,
}

pub type Collisions<T> = Vec<(T, T)>;
Expand Down Expand Up @@ -45,8 +45,9 @@ impl Router {
&'a self,
req: &'r Request<'r>
) -> impl Iterator<Item = &'a Route> + 'r {
// Note that routes are presorted by ascending rank on each `add`.
self.routes.get(&req.method())
// Note that routes are presorted by ascending rank on each `add` and
// that all routes with `None` methods have been cloned into all methods.
self.final_routes.get(&req.method())
.into_iter()
.flat_map(move |routes| routes.iter().filter(move |r| r.matches(req)))
}
Expand Down Expand Up @@ -80,14 +81,35 @@ impl Router {
})
}

pub fn finalize(&self) -> Result<(), (Collisions<Route>, Collisions<Catcher>)> {
fn _add_route(map: &mut FxHashMap<Method, Vec<Route>>, method: Method, route: Route) {
let routes = map.entry(method).or_default();
routes.push(route);
routes.sort_by_key(|r| r.rank);
}

pub fn finalize(&mut self) -> Result<(), (Collisions<Route>, Collisions<Catcher>)> {
let routes: Vec<_> = self.collisions(self.routes()).collect();
let catchers: Vec<_> = self.collisions(self.catchers()).collect();

if !routes.is_empty() || !catchers.is_empty() {
return Err((routes, catchers))
}

let all_routes = self.routes.iter()
.flat_map(|(method, routes)| routes.iter().map(|r| (*method, r)))
.map(|(method, route)| (method, route.clone()));

let mut final_routes = FxHashMap::default();
for (method, route) in all_routes {
match method {
Some(method) => Self::_add_route(&mut final_routes, method, route),
None => for method in Method::ALL_VARIANTS {
Self::_add_route(&mut final_routes, *method, route.clone());
}
}
}

self.final_routes = final_routes;
Ok(())
}
}
Expand All @@ -101,7 +123,7 @@ mod test {
use crate::http::{Method::*, uri::Origin};

impl Router {
fn has_collisions(&self) -> bool {
fn has_collisions(&mut self) -> bool {
self.finalize().is_err()
}
}
Expand Down Expand Up @@ -137,12 +159,12 @@ mod test {
}

fn rankless_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_rankless_routes(routes);
let mut router = router_with_rankless_routes(routes);
router.has_collisions()
}

fn default_rank_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_routes(routes);
let mut router = router_with_routes(routes);
router.has_collisions()
}

Expand Down Expand Up @@ -367,7 +389,7 @@ mod test {
/// Asserts that `$to` routes to `$want` given `$routes` are present.
macro_rules! assert_ranked_match {
($routes:expr, $to:expr => $want:expr) => ({
let router = router_with_routes($routes);
let mut router = router_with_routes($routes);
assert!(!router.has_collisions());
let route_path = route(&router, Get, $to).unwrap().uri.to_string();
assert_eq!(route_path, $want.to_string(),
Expand Down Expand Up @@ -401,7 +423,7 @@ mod test {
}

fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool {
let router = router_with_ranked_routes(routes);
let mut router = router_with_ranked_routes(routes);
router.has_collisions()
}

Expand Down
5 changes: 4 additions & 1 deletion core/lib/src/trace/traceable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ impl Trace for Route {
event! { level, "route",
name = self.name.as_ref().map(|n| &**n),
rank = self.rank,
method = %self.method,
method = %Formatter(|f| match self.method {
Some(method) => write!(f, "{}", method),
None => write!(f, "*"),
}),
uri = %self.uri,
uri.base = %self.uri.base(),
uri.unmounted = %self.uri.unmounted(),
Expand Down
4 changes: 2 additions & 2 deletions core/lib/tests/form_method-issue-45.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ fn patch(form_data: Form<FormData>) -> &'static str {
"PATCH OK"
}

#[route(UPDATEREDIRECTREF, uri = "/", data = "<form_data>")]
#[route(method = UPDATEREDIRECTREF, uri = "/", data = "<form_data>")]
fn urr(form_data: Form<FormData>) -> &'static str {
assert_eq!("Form data", form_data.into_inner().form_data);
"UPDATEREDIRECTREF OK"
}

#[route("VERSION-CONTROL", uri = "/", data = "<form_data>")]
#[route(method = "VERSION-CONTROL", uri = "/", data = "<form_data>")]
fn vc(form_data: Form<FormData>) -> &'static str {
assert_eq!("Form data", form_data.into_inner().form_data);
"VERSION-CONTROL OK"
Expand Down
6 changes: 5 additions & 1 deletion examples/hello/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn mir() -> &'static str {

// Try visiting:
// http://127.0.0.1:8000/wave/Rocketeer/100
#[get("/<name>/<age>")]
#[get("/<name>/<age>", rank = 2)]
fn wave(name: &str, age: u8) -> String {
format!("👋 Hello, {} year old named {}!", age, name)
}
Expand Down Expand Up @@ -72,9 +72,13 @@ fn hello(lang: Option<Lang>, opt: Options<'_>) -> String {
greeting
}

#[route(uri = "/<_..>", rank = 3)]
fn wild() { }

#[launch]
fn rocket() -> _ {
rocket::build()
.mount("/", routes![wild])
.mount("/", routes![hello])
.mount("/hello", routes![world, mir])
.mount("/wave", routes![wave])
Expand Down
Loading

0 comments on commit 2960894

Please sign in to comment.