Skip to content

Commit

Permalink
fix: protocol forward compat
Browse files Browse the repository at this point in the history
This patch switches us from using bincode, which has
pretty terrible forwards compatibility properties
to using msgpack (I think I did a switch in the
previous direction in the past). bincode is faster,
but we don't actually care about encoding speed
since we are just encoding some tiny baby little
headers at the start of a connection. This
change also adds a bunch of #[serde(default)]
annotations, which should take care of the second
half of forwards compatibilty.

This means one last protocol version break before
hopefully being able to evolve the protocol in a
much more civilized manner.

I realized I was going to have to add a protocol
field for #47, so this should be done first.

Fixes #85
  • Loading branch information
ethanpailes committed Jul 15, 2024
1 parent e21c16b commit 61a6b27
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 18 deletions.
39 changes: 29 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libshpool/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ crossbeam-channel = "0.5" # channels
libc = "0.2" # basic libc types
log = "0.4" # logging facade (not used directly, but required if we have tracing-log enabled)
tracing = "0.1" # logging and performance monitoring facade
bincode = "1" # serialization for the control protocol
rmp-serde = "1" # serialization for the control protocol
shpool_vt100 = "0.1.2" # terminal emulation for the scrollback buffer
shell-words = "1" # parsing the -c/--cmd argument
motd = { version = "0.2.2", default-features = false, features = [] } # getting the message-of-the-day
Expand Down
4 changes: 2 additions & 2 deletions libshpool/src/daemon/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ impl Server {
#[instrument(skip_all)]
fn parse_connect_header(stream: &mut UnixStream) -> anyhow::Result<protocol::ConnectHeader> {
let header: protocol::ConnectHeader =
bincode::deserialize_from(stream).context("parsing header")?;
protocol::decode_from(stream).context("parsing header")?;
Ok(header)
}

Expand All @@ -930,7 +930,7 @@ where
.context("setting write timout on inbound session")?;

let serializeable_stream = stream.try_clone().context("cloning stream handle")?;
bincode::serialize_into(serializeable_stream, &header).context("writing reply")?;
protocol::encode_to(&header, serializeable_stream).context("writing reply")?;

stream.set_write_timeout(None).context("unsetting write timout on inbound session")?;
Ok(())
Expand Down
68 changes: 63 additions & 5 deletions libshpool/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::{
default::Default,
fmt,
io::{self, Read, Write},
os::unix::net::UnixStream,
Expand All @@ -23,6 +24,7 @@ use std::{

use anyhow::{anyhow, Context};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use serde::{Deserialize, Serialize};
use serde_derive::{Deserialize, Serialize};
use tracing::{debug, error, instrument, span, trace, warn, Level};

Expand All @@ -31,6 +33,36 @@ use super::{consts, tty};
const JOIN_POLL_DUR: time::Duration = time::Duration::from_millis(100);
const JOIN_HANGUP_DUR: time::Duration = time::Duration::from_millis(300);

/// The centralized encoding function that should be used for all protocol
/// serialization.
pub fn encode_to<T, W>(d: &T, w: W) -> anyhow::Result<()>
where
T: Serialize,
W: Write,
{
// You might be worried that since we are encoding and decoding
// directly to/from the stream, unknown fields might be left trailing
// and mangle followup data, but msgpack is basically binary
// encoded json, so it has a notion of an object, which means
// it will be able to skip past the unknown fields and avoid any
// sort of mangling.
let mut serializer = rmp_serde::Serializer::new(w).with_struct_map();
d.serialize(&mut serializer).context("serializing data")?;
Ok(())
}

/// The centralized decoding focuntion that should be used for all protocol
/// deserialization.
pub fn decode_from<T, R>(r: R) -> anyhow::Result<T>
where
for<'de> T: Deserialize<'de>,
R: Read,
{
let mut deserializer = rmp_serde::Deserializer::new(r);
let d: T = Deserialize::deserialize(&mut deserializer).context("deserializing from reader")?;
Ok(d)
}

/// ConnectHeader is the blob of metadata that a client transmits when it
/// first connections. It uses an enum to allow different connection types
/// to be initiated on the same socket. The ConnectHeader is always prefixed
Expand Down Expand Up @@ -65,11 +97,13 @@ pub enum ConnectHeader {
#[derive(Serialize, Deserialize, Debug)]
pub struct KillRequest {
/// The sessions to detach
#[serde(default)]
pub sessions: Vec<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct KillReply {
#[serde(default)]
pub not_found_sessions: Vec<String>,
}

Expand All @@ -78,15 +112,18 @@ pub struct KillReply {
#[derive(Serialize, Deserialize, Debug)]
pub struct DetachRequest {
/// The sessions to detach
#[serde(default)]
pub sessions: Vec<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct DetachReply {
/// sessions that are not even in the session table
#[serde(default)]
pub not_found_sessions: Vec<String>,
/// sessions that are in the session table, but have no
/// tty attached
#[serde(default)]
pub not_attached_sessions: Vec<String>,
}

Expand All @@ -96,20 +133,23 @@ pub struct DetachReply {
#[derive(Serialize, Deserialize, Debug)]
pub struct SessionMessageRequest {
/// The session to route this request to.
#[serde(default)]
pub session_name: String,
/// The actual message to send to the session.
#[serde(default)]
pub payload: SessionMessageRequestPayload,
}

/// SessionMessageRequestPayload contains a request for
/// a running session.
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Default)]
pub enum SessionMessageRequestPayload {
/// Resize a named session's pty. Generated when
/// a `shpool attach` process receives a SIGWINCH.
Resize(ResizeRequest),
/// Detach the given session. Generated internally
/// by the server from a batch detach request.
#[default]
Detach,
}

Expand All @@ -120,6 +160,7 @@ pub enum SessionMessageRequestPayload {
#[derive(Serialize, Deserialize, Debug)]
pub struct ResizeRequest {
/// The size of the client's tty
#[serde(default)]
pub tty_size: tty::Size,
}

Expand Down Expand Up @@ -154,22 +195,27 @@ pub enum ResizeReply {
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct AttachHeader {
/// The name of the session to create or attach to.
#[serde(default)]
pub name: String,
/// The size of the local tty. Passed along so that the remote
/// pty can be kept in sync (important so curses applications look
/// right).
#[serde(default)]
pub local_tty_size: tty::Size,
/// A subset of the environment of the shell that `shpool attach` is run
/// in. Contains only some variables needed to set up the shell when
/// shpool forks off a process. For now the list is just `SSH_AUTH_SOCK`
/// and `TERM`.
#[serde(default)]
pub local_env: Vec<(String, String)>,
/// If specified, sets a time limit on how long the shell will be open
/// when the shell is first created (does nothing in the case of a
/// reattach). The daemon is responsible for automatically killing the
/// session once the ttl is over.
#[serde(default)]
pub ttl_secs: Option<u64>,
/// If specified, a command to run instead of the users default shell.
#[serde(default)]
pub cmd: Option<String>,
}

Expand All @@ -184,26 +230,32 @@ impl AttachHeader {
/// connection error.
#[derive(Serialize, Deserialize, Debug)]
pub struct AttachReplyHeader {
#[serde(default)]
pub status: AttachStatus,
}

/// ListReply is contains a list of active sessions to be displayed to the user.
#[derive(Serialize, Deserialize, Debug)]
pub struct ListReply {
#[serde(default)]
pub sessions: Vec<Session>,
}

/// Session describes an active session.
#[derive(Serialize, Deserialize, Debug)]
pub struct Session {
#[serde(default)]
pub name: String,
#[serde(default)]
pub started_at_unix_ms: i64,
#[serde(default)]
pub status: SessionStatus,
}

/// Indicates if a shpool session currently has a client attached.
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Default)]
pub enum SessionStatus {
#[default]
Attached,
Disconnected,
}
Expand Down Expand Up @@ -243,6 +295,12 @@ pub enum AttachStatus {
UnexpectedError(String),
}

impl Default for AttachStatus {
fn default() -> Self {
AttachStatus::UnexpectedError(String::from("default"))
}
}

/// ChunkKind is a tag that indicates what type of frame is being transmitted
/// through the socket.
#[derive(Copy, Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -345,16 +403,16 @@ impl Client {

pub fn write_connect_header(&mut self, header: ConnectHeader) -> anyhow::Result<()> {
let serialize_stream = self.stream.try_clone().context("cloning stream for reply")?;
bincode::serialize_into(serialize_stream, &header).context("writing reply")?;
encode_to(&header, serialize_stream).context("writing reply")?;

Ok(())
}

pub fn read_reply<R>(&mut self) -> anyhow::Result<R>
where
R: serde::de::DeserializeOwned,
R: for<'de> serde::Deserialize<'de>,
{
let reply: R = bincode::deserialize_from(&mut self.stream).context("parsing header")?;
let reply: R = decode_from(&mut self.stream).context("parsing header")?;
Ok(reply)
}

Expand Down

0 comments on commit 61a6b27

Please sign in to comment.