From f22bf5fc8deeef6ac3654d170203474247976da5 Mon Sep 17 00:00:00 2001 From: Ryan Butler Date: Sun, 10 Mar 2024 15:57:51 -0400 Subject: [PATCH] Handoff to instance now works --- Cargo.lock | 7 +- Cargo.toml | 2 +- crates/nexus-voicechat/Cargo.toml | 11 -- crates/nexus-voicechat/src/lib.rs | 14 -- .../client/examples/example-client.rs | 16 +- crates/replicate/client/src/instance.rs | 99 ++++++------- crates/replicate/client/src/manager.rs | 18 ++- crates/replicate/common/Cargo.toml | 1 + .../replicate/common/src/messages/instance.rs | 11 ++ .../replicate/common/src/messages/manager.rs | 3 + crates/replicate/common/src/messages/mod.rs | 1 + crates/replicate/server/Cargo.toml | 2 + crates/replicate/server/src/chad/mod.rs | 138 +++++++++++++++--- 13 files changed, 222 insertions(+), 101 deletions(-) delete mode 100644 crates/nexus-voicechat/Cargo.toml delete mode 100644 crates/nexus-voicechat/src/lib.rs create mode 100644 crates/replicate/common/src/messages/instance.rs diff --git a/Cargo.lock b/Cargo.lock index 6e60070..46761c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5191,10 +5191,6 @@ dependencies = [ "jni-sys", ] -[[package]] -name = "nexus-voicechat" -version = "0.0.0" - [[package]] name = "nix" version = "0.24.3" @@ -6357,6 +6353,7 @@ dependencies = [ "tokio", "tokio-serde", "tokio-util", + "url", "uuid 1.7.0", ] @@ -6369,6 +6366,7 @@ dependencies = [ "clap", "color-eyre", "dashmap", + "derive_more", "eyre", "futures", "replicate-common", @@ -6378,6 +6376,7 @@ dependencies = [ "tokio-util", "tracing", "tracing-subscriber", + "url", "uuid 1.7.0", "wtransport", ] diff --git a/Cargo.toml b/Cargo.toml index 2103853..e2aa7c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,6 @@ members = [ "apps/social/networking", "apps/social/server", "crates/egui-picking", - "crates/nexus-voicechat", "crates/picking-xr", "crates/replicate/client", "crates/replicate/common", @@ -74,6 +73,7 @@ features = [ "deref", "deref_mut", "mul", + "from", ] [workspace.dependencies.opus] diff --git a/crates/nexus-voicechat/Cargo.toml b/crates/nexus-voicechat/Cargo.toml deleted file mode 100644 index af44080..0000000 --- a/crates/nexus-voicechat/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "nexus-voicechat" -version.workspace = true -license.workspace = true -repository.workspace = true -edition.workspace = true -rust-version.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] diff --git a/crates/nexus-voicechat/src/lib.rs b/crates/nexus-voicechat/src/lib.rs deleted file mode 100644 index 06d268d..0000000 --- a/crates/nexus-voicechat/src/lib.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } -} diff --git a/crates/replicate/client/examples/example-client.rs b/crates/replicate/client/examples/example-client.rs index f963577..7b2424c 100644 --- a/crates/replicate/client/examples/example-client.rs +++ b/crates/replicate/client/examples/example-client.rs @@ -1,6 +1,6 @@ use clap::Parser; use color_eyre::{eyre::WrapErr, Result}; -use replicate_client::manager::Manager; +use replicate_client::{instance::Instance, manager::Manager}; use replicate_common::did::{AuthenticationAttestation, Did, DidPrivateKey}; use tracing::info; use tracing_subscriber::{filter::LevelFilter, EnvFilter}; @@ -36,7 +36,7 @@ async fn main() -> Result<()> { let auth_attest = AuthenticationAttestation::new(did, &did_private_key); - let mut manager = Manager::connect(args.url, auth_attest) + let mut manager = Manager::connect(args.url, &auth_attest) .await .wrap_err("failed to connect to manager")?; info!("Connected to manager!"); @@ -45,7 +45,17 @@ async fn main() -> Result<()> { .instance_create() .await .wrap_err("failed to create instance")?; - info!("Got instance: {instance_id}"); + + let instance_url = manager + .instance_url(instance_id) + .await + .wrap_err("failed to get instance url")?; + info!("Got instance {instance_id} at: {instance_url}"); + + let _instance = Instance::connect(instance_url, auth_attest) + .await + .wrap_err("failed to connect to instance")?; + info!("Connected to instance!"); Ok(()) } diff --git a/crates/replicate/client/src/instance.rs b/crates/replicate/client/src/instance.rs index c6c416c..882c5fc 100644 --- a/crates/replicate/client/src/instance.rs +++ b/crates/replicate/client/src/instance.rs @@ -82,27 +82,34 @@ use std::sync::{atomic::AtomicU16, RwLock, RwLockReadGuard, RwLockWriteGuard}; use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD}; +use eyre::{bail, ensure, Result, WrapErr}; +use futures::{SinkExt, StreamExt}; use replicate_common::{ data_model::{DataModel, Entity, State}, did::AuthenticationAttestation, }; use tracing::warn; use url::Url; -use wtransport::{endpoint::ConnectOptions, ClientConfig, Endpoint, RecvStream}; +use wtransport::{endpoint::ConnectOptions, ClientConfig, Endpoint}; use crate::CertHashDecodeErr; +use replicate_common::messages::instance::{Clientbound as Cb, Serverbound as Sb}; +type RpcFramed = replicate_common::Framed; + /// Client api for interacting with a particular instance on the server. /// Instances manage persistent, realtime state updates for many concurrent clients. #[derive(Debug)] pub struct Instance { _conn: wtransport::Connection, - _url: String, + _url: Url, /// Used to reliably push state updates from server to client. This happens for all /// entities when the client initially connects, as well as when the server is /// marking an entity as "stable", meaning its state is no longer changing frame to /// frame. This allows the server to reduce network bandwidth. - _stable_states: RecvStream, + // _stable_states: RecvStream, + /// Used for general RPC. + _rpc: RpcFramed, /// Current sequence number. // TODO: Figure out how sequence numbers work _state_seq: StateSeq, @@ -118,9 +125,42 @@ impl Instance { pub async fn connect( url: Url, auth_attest: AuthenticationAttestation, - ) -> Result { - let _conn = connect_to_url(url, auth_attest).await?; - todo!() + ) -> Result { + let conn = connect_to_url(&url, auth_attest) + .await + .wrap_err("failed to connect to server")?; + + let bi = wtransport::stream::BiStream::join( + conn.open_bi() + .await + .wrap_err("could not initiate bi stream")? + .await + .wrap_err("could not finish opening bi stream")?, + ); + let mut rpc = RpcFramed::new(bi); + + // Do handshake before anything else + { + rpc.send(Sb::HandshakeRequest) + .await + .wrap_err("failed to send handshake request")?; + let Some(msg) = rpc.next().await else { + bail!("Server disconnected before completing handshake"); + }; + let msg = msg.wrap_err("error while receiving handshake response")?; + ensure!( + msg == Cb::HandshakeResponse, + "invalid message during handshake" + ); + } + + Ok(Self { + _conn: conn, + _url: url, + _state_seq: Default::default(), + dm: RwLock::new(DataModel::new()), + _rpc: rpc, + }) } /// Asks the server to reserve for this client a list of entity ids and store them @@ -130,7 +170,7 @@ impl Instance { pub async fn reserve_entities( &self, #[allow(clippy::ptr_arg)] _entities: &mut Vec, - ) -> Result<(), ReserveErr> { + ) -> Result<()> { todo!() } @@ -154,52 +194,13 @@ pub enum RecvState<'a> { } /// Sequence number for state messages -#[derive(Debug)] +#[derive(Debug, Default)] pub struct StateSeq(AtomicU16); -mod error { - use crate::CertHashDecodeErr; - use wtransport::error::{ConnectingError, SendDatagramError, StreamWriteError}; - - #[derive(thiserror::Error, Debug)] - pub enum ReserveErr {} - - #[derive(thiserror::Error, Debug)] - pub enum DeleteErr {} - - #[derive(thiserror::Error, Debug)] - pub enum SendStateErr { - #[error("error while sending state across network: {0}")] - Dgram(#[from] SendDatagramError), - } - - #[derive(thiserror::Error, Debug)] - pub enum SendReliableStateErr { - #[error("error while finalizing state to network: {0}")] - StreamWrite(#[from] StreamWriteError), - } - - #[derive(thiserror::Error, Debug)] - pub enum RecvStateErr {} - - #[derive(thiserror::Error, Debug)] - pub enum ConnectErr { - #[error("failed to create webtransport client: {0}")] - ClientCreate(#[from] std::io::Error), - #[error("failed to connect to webtransport endoint: {0}")] - WtConnectingError(#[from] ConnectingError), - #[error(transparent)] - InvalidCertHash(#[from] CertHashDecodeErr), - #[error(transparent)] - Other(#[from] Box), - } -} -pub use self::error::*; - async fn connect_to_url( - url: Url, + url: &Url, auth_attest: AuthenticationAttestation, -) -> Result { +) -> Result { let cert_hash = if let Some(frag) = url.fragment() { let cert_hash = BASE64_URL_SAFE_NO_PAD .decode(frag) diff --git a/crates/replicate/client/src/manager.rs b/crates/replicate/client/src/manager.rs index 3b47448..687b83f 100644 --- a/crates/replicate/client/src/manager.rs +++ b/crates/replicate/client/src/manager.rs @@ -44,7 +44,7 @@ impl Manager { /// our DID. pub async fn connect( url: Url, - auth_attest: AuthenticationAttestation, + auth_attest: &AuthenticationAttestation, ) -> Result { let cert_hash = if let Some(frag) = url.fragment() { let cert_hash = BASE64_URL_SAFE_NO_PAD @@ -105,6 +105,7 @@ impl Manager { "invalid message during handshake" ); } + Ok(Self { _conn: conn, _url: url, @@ -126,4 +127,19 @@ impl Manager { Some(Ok(_)) => Err(eyre!("unexpected response")), } } + + pub async fn instance_url(&mut self, id: InstanceId) -> Result { + self.framed + .send(Sb::InstanceUrlRequest { id }) + .await + .wrap_err("failed to write message")?; + match self.framed.next().await { + None => Err(eyre!("server disconnected")), + Some(Err(err)) => { + Err(eyre::Report::new(err).wrap_err("failed to receive message")) + } + Some(Ok(Cb::InstanceUrlResponse { url })) => Ok(url), + Some(Ok(_)) => Err(eyre!("unexpected response")), + } + } } diff --git a/crates/replicate/common/Cargo.toml b/crates/replicate/common/Cargo.toml index 2ad5167..44e8820 100644 --- a/crates/replicate/common/Cargo.toml +++ b/crates/replicate/common/Cargo.toml @@ -16,4 +16,5 @@ thiserror.workspace = true tokio = { workspace = true, default-features = false } tokio-serde.workspace = true tokio-util.workspace = true +url.workspace = true uuid = { workspace = true, features = ["v4", "serde"] } diff --git a/crates/replicate/common/src/messages/instance.rs b/crates/replicate/common/src/messages/instance.rs new file mode 100644 index 0000000..73591cf --- /dev/null +++ b/crates/replicate/common/src/messages/instance.rs @@ -0,0 +1,11 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Eq, PartialEq)] +pub enum Serverbound { + HandshakeRequest, +} + +#[derive(Serialize, Deserialize, Eq, PartialEq)] +pub enum Clientbound { + HandshakeResponse, +} diff --git a/crates/replicate/common/src/messages/manager.rs b/crates/replicate/common/src/messages/manager.rs index cb78052..f3ede7f 100644 --- a/crates/replicate/common/src/messages/manager.rs +++ b/crates/replicate/common/src/messages/manager.rs @@ -1,15 +1,18 @@ use serde::{Deserialize, Serialize}; +use url::Url; use crate::InstanceId; #[derive(Serialize, Deserialize, Eq, PartialEq)] pub enum Serverbound { + InstanceUrlRequest { id: InstanceId }, InstanceCreateRequest, HandshakeRequest, } #[derive(Serialize, Deserialize, Eq, PartialEq)] pub enum Clientbound { + InstanceUrlResponse { url: Url }, InstanceCreateResponse { id: InstanceId }, HandshakeResponse, } diff --git a/crates/replicate/common/src/messages/mod.rs b/crates/replicate/common/src/messages/mod.rs index c8dc873..80599a6 100644 --- a/crates/replicate/common/src/messages/mod.rs +++ b/crates/replicate/common/src/messages/mod.rs @@ -2,4 +2,5 @@ //! We should switch to protobuf or capnproto as soon as we prove the networking //! works. +pub mod instance; pub mod manager; diff --git a/crates/replicate/server/Cargo.toml b/crates/replicate/server/Cargo.toml index a94f994..2e070d0 100644 --- a/crates/replicate/server/Cargo.toml +++ b/crates/replicate/server/Cargo.toml @@ -13,6 +13,7 @@ bytes.workspace = true clap.workspace = true color-eyre.workspace = true dashmap = "5.5.3" +derive_more.workspace = true eyre.workspace = true futures.workspace = true replicate-common.path = "../common" @@ -21,6 +22,7 @@ tokio-serde = { workspace = true, features = ["json"] } tokio-util = { workspace = true, features = ["codec"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } tracing.workspace = true +url.workspace = true uuid = { version = "1.6.1", features = ["v4", "serde"] } wtransport.workspace = true diff --git a/crates/replicate/server/src/chad/mod.rs b/crates/replicate/server/src/chad/mod.rs index 0b87fca..64ddf75 100644 --- a/crates/replicate/server/src/chad/mod.rs +++ b/crates/replicate/server/src/chad/mod.rs @@ -1,6 +1,10 @@ //! WebTransport server, i.e. "chad" transport. -use std::{num::Wrapping, sync::Arc, time::Duration}; +use std::{ + num::Wrapping, + sync::{Arc, RwLock}, + time::Duration, +}; use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD}; use color_eyre::{eyre::Context, Result}; @@ -11,7 +15,8 @@ use replicate_common::{ InstanceId, }; use tracing::{error, info, info_span, Instrument}; -use wtransport::{endpoint::IncomingSession, Certificate, ServerConfig}; +use url::Url; +use wtransport::{endpoint::IncomingSession, ServerConfig}; use crate::{instance::InstanceManager, Args}; @@ -24,15 +29,13 @@ pub async fn launch_webtransport_server( args: Args, _im: Arc, ) -> Result<()> { - let mut cert = Certificate::self_signed(args.subject_alt_names.iter()); - let domain_name = args - .subject_alt_names - .first() - .expect("should be at least one SAN"); + let cert = Certificate::new(wtransport::Certificate::self_signed( + args.subject_alt_names.iter(), + )); let server = Server::server( ServerConfig::builder() .with_bind_default(args.port.unwrap_or(0)) - .with_certificate(cert.clone()) + .with_certificate(cert.cert.clone()) .build(), ) .wrap_err("failed to create wtransport server")?; @@ -42,16 +45,26 @@ pub async fn launch_webtransport_server( .expect("could not determine port") .port(); - info!("server url:\n{}", server_url(domain_name, port, &cert)); + let svr_ctx = ServerCtx::new(ServerCtxInner { + san: args.subject_alt_names, + port, + cert, + }); + + { + let svr_ctx = svr_ctx.0.read().expect("lock poisoned"); + info!("server url:\n{}", server_url(&svr_ctx)); + } let mut id = Wrapping(0u64); let accept_fut = async { loop { let incoming = server.accept().await; + let svr_ctx_clone = svr_ctx.clone(); id += 1; tokio::spawn( async move { - if let Err(err) = handle_connection(incoming).await { + if let Err(err) = handle_connection(svr_ctx_clone, incoming).await { error!("terminated with error: {err:?}"); } else { info!("disconnected"); @@ -69,13 +82,15 @@ pub async fn launch_webtransport_server( loop { interval.tick().await; info!("refreshing certs"); - cert = Certificate::self_signed(args.subject_alt_names.iter()); - #[allow(clippy::question_mark)] + let mut svr_ctx_l = svr_ctx.0.write().expect("server context poisoned"); + svr_ctx_l.cert = + wtransport::Certificate::self_signed(svr_ctx_l.san.iter()).into(); + if let Err(err) = server .reload_config( ServerConfig::builder() .with_bind_default(args.port.unwrap_or(0)) - .with_certificate(cert.clone()) + .with_certificate(svr_ctx_l.cert.cert.clone()) .build(), false, ) @@ -83,7 +98,7 @@ pub async fn launch_webtransport_server( { return Err(err); } - info!("new server url:\n{}", server_url(domain_name, port, &cert)); + info!("new server url:\n{}", server_url(&svr_ctx_l)); } } .instrument(info_span!("cert refresh task")); @@ -93,7 +108,10 @@ pub async fn launch_webtransport_server( } } -async fn handle_connection(incoming: IncomingSession) -> Result<()> { +async fn handle_connection( + svr_ctx: ServerCtx, + incoming: IncomingSession, +) -> Result<()> { info!("Waiting for session request..."); let session_request = incoming.await?; info!( @@ -137,6 +155,20 @@ async fn handle_connection(incoming: IncomingSession) -> Result<()> { }; framed.send(response).await?; } + Sb::InstanceUrlRequest { id } => { + let url = { + let svr_ctx_l = svr_ctx.0.read().expect("poisoned"); + let domain_name = + svr_ctx_l.san.first().expect("should have domain name"); + let port = svr_ctx_l.port; + let hash = &svr_ctx_l.cert.base64; + // TODO: Actually manipulate the instance manager. + Url::parse(&format!("https://{domain_name}:{port}/{id}/#{hash}")) + .expect("invalid url") + }; + let response = Cb::InstanceUrlResponse { url }; + framed.send(response).await?; + } Sb::HandshakeRequest => { bail!("already did handshake, another handshake is unexpected") } @@ -147,12 +179,82 @@ async fn handle_connection(incoming: IncomingSession) -> Result<()> { Ok(()) } -fn server_url(subject_alt_name: &str, port: u16, cert: &Certificate) -> String { - let cert_hash = cert.hashes().pop().expect("should be at least one hash"); - let encoded_cert_hash = BASE64_URL_SAFE_NO_PAD.encode(cert_hash.as_ref()); +fn server_url(svr_ctx: &ServerCtxInner) -> String { + let encoded_cert_hash = &svr_ctx.cert.base64; + let subject_alt_name = svr_ctx.san.first().expect("should have at least 1 SAN"); + let port = svr_ctx.port; format!("https://{subject_alt_name}:{port}/#{encoded_cert_hash}") } +/// Server state, concurrently shared across all connections +#[derive(Debug, Clone)] +struct ServerCtxInner { + san: Vec, + port: u16, + cert: Certificate, +} + +#[derive(Debug, Clone)] +struct ServerCtx(Arc>); + +impl ServerCtx { + fn new(ctx: ServerCtxInner) -> Self { + Self(Arc::new(RwLock::new(ctx))) + } +} + +/// Connection state +#[derive(Debug)] +struct ConnectionCtx {} + +/// Newtype on [`wtransport::Certificate`]. +#[derive(Clone)] +struct Certificate { + cert: wtransport::Certificate, + base64: String, +} + +impl Certificate { + fn new(cert: wtransport::Certificate) -> Self { + let cert_hash = cert.hashes().pop().expect("should be at least one hash"); + let base64 = BASE64_URL_SAFE_NO_PAD.encode(cert_hash.as_ref()); + Self { cert, base64 } + } +} + +impl From for Certificate { + fn from(value: wtransport::Certificate) -> Self { + Self::new(value) + } +} + +impl std::fmt::Debug for Certificate { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple(std::any::type_name::()) + .field(&self.cert.certificates()) + .finish() + } +} + +impl std::fmt::Display for Certificate { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_list() + .entries( + self.cert + .hashes() + .into_iter() + .map(|h| h.fmt(wtransport::tls::Sha256DigestFmt::DottedHex)), + ) + .finish() + } +} + +impl AsRef for Certificate { + fn as_ref(&self) -> &wtransport::Certificate { + &self.cert + } +} + #[cfg(test)] mod test { use std::time::Duration;