diff --git a/matchbox_signaling/examples/tls.rs b/matchbox_signaling/examples/tls.rs new file mode 100644 index 00000000..ed578f6d --- /dev/null +++ b/matchbox_signaling/examples/tls.rs @@ -0,0 +1,33 @@ +use matchbox_signaling::SignalingServer; +use std::net::Ipv4Addr; +use tracing::info; + +#[tokio::main] +async fn main() -> Result<(), matchbox_signaling::Error> { + setup_logging(); + + let server = SignalingServer::full_mesh_builder((Ipv4Addr::UNSPECIFIED, 3536)) + .on_connection_request(|connection| { + info!("Connecting: {connection:?}"); + Ok(true) // Allow all connections + }) + .on_id_assignment(|(socket, id)| info!("{socket} received {id}")) + .on_peer_connected(|id| info!("Joined: {id}")) + .on_peer_disconnected(|id| info!("Left: {id}")) + .tls("", "") + .await + .trace() + .build(); + server.serve().await +} + +fn setup_logging() { + use tracing_subscriber::prelude::*; + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); +} diff --git a/matchbox_signaling/src/signaling_server/builder.rs b/matchbox_signaling/src/signaling_server/builder.rs index 17c48574..40f480b1 100644 --- a/matchbox_signaling/src/signaling_server/builder.rs +++ b/matchbox_signaling/src/signaling_server/builder.rs @@ -8,7 +8,7 @@ use crate::{ SignalingCallbacks, SignalingServer, SignalingState, }; use axum::{response::Response, routing::get, Extension, Router}; -use axum_server::{accept::DefaultAcceptor, tls_rustls::RustlsConfig, Handle}; +use axum_server::{tls_rustls::RustlsConfig, Handle}; use matchbox_protocol::PeerId; use std::net::SocketAddr; use std::path::Path; @@ -23,7 +23,7 @@ use tracing::Level; /// /// Begin with [`SignalingServerBuilder::new`] and add parameters before calling /// [`SignalingServerBuilder::build`] to produce the desired [`SignalingServer`]. -pub struct SignalingServerBuilder +pub struct SignalingServerBuilder where Topology: SignalingTopology, Cb: SignalingCallbacks, @@ -47,8 +47,8 @@ where /// Arbitrary state accompanying a server pub(crate) state: S, - /// A server acceptor - pub(crate) acceptor: Tls, + /// Tls config + pub(crate) tls: Option, } impl SignalingServerBuilder @@ -58,19 +58,15 @@ where S: SignalingState, { /// Creates a new builder for a [`SignalingServer`]. - pub fn new( - socket_addr: impl Into, - topology: Topology, - state: S, - ) -> SignalingServerBuilder { - SignalingServerBuilder { + pub fn new(socket_addr: impl Into, topology: Topology, state: S) -> Self { + Self { socket_addr: socket_addr.into(), router: Router::new(), shared_callbacks: SharedCallbacks::default(), callbacks: Cb::default(), topology, state, - acceptor: DefaultAcceptor, + tls: None, } } @@ -101,21 +97,10 @@ where } /// Configure TLS with a certificate (.pem) and private key (.key) file - pub async fn tls( - self, - cert: impl AsRef, - key: impl AsRef, - ) -> SignalingServerBuilder { + pub async fn tls(mut self, cert: impl AsRef, key: impl AsRef) -> Self { let config = RustlsConfig::from_pem_file(cert, key).await.unwrap(); - SignalingServerBuilder { - socket_addr: self.socket_addr, - router: self.router, - shared_callbacks: self.shared_callbacks, - callbacks: self.callbacks, - topology: self.topology, - state: self.state, - acceptor: config, - } + self.tls.replace(config); + self } /// Apply permissive CORS middleware for debug purposes. @@ -141,15 +126,8 @@ where ); self } -} -impl SignalingServerBuilder -where - Topology: SignalingTopology, - Cb: SignalingCallbacks, - S: SignalingState, -{ - /// Create a [`SignalingServer`]. + // Create a [`SignalingServer`]. /// /// # Panics /// This method will panic if the socket address requested cannot be bound. @@ -165,46 +143,25 @@ where .layer(Extension(self.callbacks)) .layer(Extension(self.state)); let handle = Handle::new(); - SignalingServer { - server: Box::pin( - axum_server::bind(self.socket_addr) - .handle(handle.clone()) - .serve(router.into_make_service_with_connect_info::()), - ), - handle, - } - } -} -impl SignalingServerBuilder -where - Topology: SignalingTopology, - Cb: SignalingCallbacks, - S: SignalingState, -{ - /// Create a [`SignalingServer`]. - /// - /// # Panics - /// This method will panic if the socket address requested cannot be bound. - pub fn build(self) -> SignalingServer { - let state_machine: SignalingStateMachine = - SignalingStateMachine::from_topology(self.topology); - let router = self - .router - .route("/", get(ws_handler::)) - .route("/:path", get(ws_handler::)) - .layer(Extension(state_machine)) - .layer(Extension(self.shared_callbacks)) - .layer(Extension(self.callbacks)) - .layer(Extension(self.state)); - let handle = Handle::new(); - SignalingServer { - server: Box::pin( - axum_server::bind_rustls(self.socket_addr, self.acceptor) - .handle(handle.clone()) - .serve(router.into_make_service_with_connect_info::()), - ), - handle, + if let Some(config) = self.tls { + SignalingServer { + server: Box::pin( + axum_server::bind_rustls(self.socket_addr, config) + .handle(handle.clone()) + .serve(router.into_make_service_with_connect_info::()), + ), + handle, + } + } else { + SignalingServer { + server: Box::pin( + axum_server::bind(self.socket_addr) + .handle(handle.clone()) + .serve(router.into_make_service_with_connect_info::()), + ), + handle, + } } } } diff --git a/matchbox_signaling/src/signaling_server/server.rs b/matchbox_signaling/src/signaling_server/server.rs index a8897505..0682bd88 100644 --- a/matchbox_signaling/src/signaling_server/server.rs +++ b/matchbox_signaling/src/signaling_server/server.rs @@ -5,7 +5,7 @@ use crate::{ full_mesh::{FullMesh, FullMeshCallbacks, FullMeshState}, }, }; -use axum_server::{accept::DefaultAcceptor, Handle}; +use axum_server::Handle; use futures::Future; use std::{io, net::SocketAddr, pin::Pin}; @@ -24,22 +24,14 @@ impl SignalingServer { pub fn full_mesh_builder( socket_addr: impl Into, ) -> SignalingServerBuilder { - SignalingServerBuilder::<_, _, _, DefaultAcceptor>::new( - socket_addr, - FullMesh, - FullMeshState::default(), - ) + SignalingServerBuilder::new(socket_addr, FullMesh, FullMeshState::default()) } /// Creates a new builder for a [`SignalingServer`] with client-server topology. pub fn client_server_builder( socket_addr: impl Into, ) -> SignalingServerBuilder { - SignalingServerBuilder::<_, _, _, DefaultAcceptor>::new( - socket_addr, - ClientServer, - ClientServerState::default(), - ) + SignalingServerBuilder::new(socket_addr, ClientServer, ClientServerState::default()) } /// Returns a clone to a server handle for introspection.