Skip to content

Commit

Permalink
fix: add version negotiation warnings
Browse files Browse the repository at this point in the history
This patch adds some version negotiation logic
and appropriate warnings to shpool. Shpool remains
permissive about version mismatches, only issuing
warnings so that the user may continue to perform
any operations which still function correctly, but
it now provides recommendations on how to
address the issue.

The protocol version is determined by the new
shpool-protocol crate.

Fixes #88
  • Loading branch information
ethanpailes committed Aug 26, 2024
1 parent 25c9962 commit e62a2e8
Show file tree
Hide file tree
Showing 14 changed files with 377 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libshpool/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ strip-ansi-escapes = "0.2.0" # cleaning up strings for pager display
notify = "6" # watch config file for updates
libproc = "0.14.8" # sniffing shells by examining the subprocess
daemonize = "0.5" # autodaemonization
shpool-protocol = { version = "0.1.0", path = "../shpool-protocol" } # client-server protocol
shpool-protocol = { version = "0.2.0", path = "../shpool-protocol" } # client-server protocol

# rusty wrapper for unix apis
[dependencies.nix]
Expand Down
23 changes: 20 additions & 3 deletions libshpool/src/attach.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use shpool_protocol::{
};
use tracing::{error, info, warn};

use super::{config, duration, protocol, test_hooks, tty::TtySizeExt as _};
use super::{config, duration, protocol, protocol::ClientResult, test_hooks, tty::TtySizeExt as _};

const MAX_FORCE_RETRIES: usize = 20;

Expand Down Expand Up @@ -188,7 +188,18 @@ fn do_attach(

fn dial_client(socket: &PathBuf) -> anyhow::Result<protocol::Client> {
match protocol::Client::new(socket) {
Ok(c) => Ok(c),
Ok(ClientResult::JustClient(c)) => Ok(c),
Ok(ClientResult::VersionMismatch { warning, client }) => {
eprintln!("warning: {}, try restarting your daemon", warning);
eprintln!("hit enter to continue anyway or ^C to exit");

let _ = io::stdin()
.lines()
.next()
.context("waiting for a continue through a version mismatch")?;

Ok(client)
}
Err(err) => {
let io_err = err.downcast::<io::Error>()?;
if io_err.kind() == io::ErrorKind::NotFound {
Expand Down Expand Up @@ -239,7 +250,13 @@ impl SignalHandler {

fn handle_sigwinch(&self) -> anyhow::Result<()> {
info!("handle_sigwinch: enter");
let mut client = protocol::Client::new(&self.socket)?;
let mut client = match protocol::Client::new(&self.socket)? {
ClientResult::JustClient(c) => c,
// At this point, we've already warned the user and they
// chose to continue anyway, so we shouldn't bother them
// again.
ClientResult::VersionMismatch { client, .. } => client,
};

let tty_size = TtySize::from_fd(0).context("getting tty size")?;
info!("handle_sigwinch: tty_size={:?}", tty_size);
Expand Down
17 changes: 17 additions & 0 deletions libshpool/src/daemon/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use shpool_protocol::{
AttachHeader, AttachReplyHeader, AttachStatus, ConnectHeader, DetachReply, DetachRequest,
KillReply, KillRequest, ListReply, ResizeReply, Session, SessionMessageDetachReply,
SessionMessageReply, SessionMessageRequest, SessionMessageRequestPayload, SessionStatus,
VersionHeader,
};
use tracing::{error, info, instrument, span, trace, warn, Level};

Expand Down Expand Up @@ -139,6 +140,22 @@ impl Server {
.set_read_timeout(Some(consts::SOCK_STREAM_TIMEOUT))
.context("setting read timout on inbound session")?;

// advertize our protocol version to the client so that it can
// warn about mismatches
protocol::encode_to(
&VersionHeader {
// We allow fake version to be injected for ease of testing.
// Otherwise we would have to resort to some heinous build
// contortions.
version: match env::var("SHPOOL_TEST__OVERRIDE_VERSION") {
Ok(fake_version) => fake_version,
Err(_) => String::from(shpool_protocol::VERSION),
},
},
&mut stream,
)
.context("writing version header")?;

let header = parse_connect_header(&mut stream).context("parsing connect header")?;

if let Err(err) = check_peer(&stream) {
Expand Down
8 changes: 6 additions & 2 deletions libshpool/src/detach.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@ use std::{io, path::Path};
use anyhow::{anyhow, Context};
use shpool_protocol::{ConnectHeader, DetachReply, DetachRequest};

use crate::{common, protocol};
use crate::{common, protocol, protocol::ClientResult};

pub fn run<P>(mut sessions: Vec<String>, socket: P) -> anyhow::Result<()>
where
P: AsRef<Path>,
{
let mut client = match protocol::Client::new(socket) {
Ok(c) => c,
Ok(ClientResult::JustClient(c)) => c,
Ok(ClientResult::VersionMismatch { warning, client }) => {
eprintln!("warning: {}, try restarting your daemon", warning);
client
}
Err(err) => {
let io_err = err.downcast::<io::Error>()?;
if io_err.kind() == io::ErrorKind::NotFound {
Expand Down
8 changes: 6 additions & 2 deletions libshpool/src/kill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@ use std::{io, path::Path};
use anyhow::{anyhow, Context};
use shpool_protocol::{ConnectHeader, KillReply, KillRequest};

use crate::{common, protocol};
use crate::{common, protocol, protocol::ClientResult};

pub fn run<P>(mut sessions: Vec<String>, socket: P) -> anyhow::Result<()>
where
P: AsRef<Path>,
{
let mut client = match protocol::Client::new(socket) {
Ok(c) => c,
Ok(ClientResult::JustClient(c)) => c,
Ok(ClientResult::VersionMismatch { warning, client }) => {
eprintln!("warning: {}, try restarting your daemon", warning);
client
}
Err(err) => {
let io_err = err.downcast::<io::Error>()?;
if io_err.kind() == io::ErrorKind::NotFound {
Expand Down
8 changes: 6 additions & 2 deletions libshpool/src/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ use std::{io, path::PathBuf, time};
use anyhow::Context;
use shpool_protocol::{ConnectHeader, ListReply};

use crate::protocol;
use crate::{protocol, protocol::ClientResult};

pub fn run(socket: PathBuf) -> anyhow::Result<()> {
let mut client = match protocol::Client::new(socket) {
Ok(c) => c,
Ok(ClientResult::JustClient(c)) => c,
Ok(ClientResult::VersionMismatch { warning, client }) => {
eprintln!("warning: {}, try restarting your daemon", warning);
client
}
Err(err) => {
let io_err = err.downcast::<io::Error>()?;
if io_err.kind() == io::ErrorKind::NotFound {
Expand Down
157 changes: 148 additions & 9 deletions libshpool/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::{
cmp,
io::{self, Read, Write},
os::unix::net::UnixStream,
path::Path,
Expand All @@ -23,8 +24,8 @@ use std::{
use anyhow::{anyhow, Context};
use byteorder::{LittleEndian, ReadBytesExt as _, WriteBytesExt as _};
use serde::{Deserialize, Serialize};
use shpool_protocol::{Chunk, ChunkKind, ConnectHeader};
use tracing::{debug, error, instrument, span, trace, warn, Level};
use shpool_protocol::{Chunk, ChunkKind, ConnectHeader, VersionHeader};
use tracing::{debug, error, info, instrument, span, trace, warn, Level};

use super::{consts, tty};

Expand Down Expand Up @@ -120,19 +121,70 @@ impl<'data> ChunkExt<'data> for Chunk<'data> {
}

pub struct Client {
pub stream: UnixStream,
stream: UnixStream,
}

/// The result of creating a client, possibly with
/// flagging some issues that need to be handled.
pub enum ClientResult {
/// The created client, ready to go.
JustClient(Client),
/// There was a version mismatch between the client
/// process and the daemon process which ought to be
/// handled, though it is possible that some operations
/// will continue to work.
VersionMismatch {
/// A warning about a version mismatch that should be
/// displayed to the user.
warning: String,
/// The client, which may or may not work.
client: Client,
},
}

impl Client {
pub fn new<P: AsRef<Path>>(sock: P) -> anyhow::Result<Self> {
/// Create a new client
#[allow(clippy::new_ret_no_self)]
pub fn new<P: AsRef<Path>>(sock: P) -> anyhow::Result<ClientResult> {
let stream = UnixStream::connect(sock).context("connecting to shpool")?;
Ok(Client { stream })
}

pub fn write_connect_header(&mut self, header: ConnectHeader) -> anyhow::Result<()> {
let serialize_stream = self.stream.try_clone().context("cloning stream for reply")?;
encode_to(&header, serialize_stream).context("writing reply")?;
let daemon_version: VersionHeader = match decode_from(&stream) {
Ok(v) => v,
Err(e) => {
warn!("error parsing VersionHeader: {:?}", e);
return Ok(ClientResult::VersionMismatch {
warning: String::from("could not get daemon version"),
client: Client { stream },
});
}
};
info!("read daemon version header: {:?}", daemon_version);

match Self::version_ord(shpool_protocol::VERSION, &daemon_version.version)
.context("comparing versions")?
{
cmp::Ordering::Equal => Ok(ClientResult::JustClient(Client { stream })),
cmp::Ordering::Less => Ok(ClientResult::VersionMismatch {
warning: format!(
"client protocol (version {:?}) is older than daemon protocol (version {:?})",
shpool_protocol::VERSION,
daemon_version.version,
),
client: Client { stream },
}),
cmp::Ordering::Greater => Ok(ClientResult::VersionMismatch {
warning: format!(
"client protocol ({:?}) is newer than daemon protocol (version {:?})",
shpool_protocol::VERSION,
daemon_version.version,
),
client: Client { stream },
}),
}
}

pub fn write_connect_header(&self, header: ConnectHeader) -> anyhow::Result<()> {
encode_to(&header, &self.stream).context("writing reply")?;
Ok(())
}

Expand All @@ -144,6 +196,43 @@ impl Client {
Ok(reply)
}

/// This is essentially just PartialOrd on client version strings
/// with more descriptive errors (since PartialOrd gives an option)
/// and without having to wrap in a newtype.
fn version_ord(client_version: &str, daemon_version: &str) -> anyhow::Result<cmp::Ordering> {
let client_parts = client_version
.split('.')
.map(|p| p.parse::<i64>())
.collect::<Result<Vec<_>, _>>()
.context("parsing client version")?;
if client_parts.len() != 3 {
return Err(anyhow!(
"parsing client version: got {} parts, want 3",
client_parts.len(),
));
}

let daemon_parts = daemon_version
.split('.')
.map(|p| p.parse::<i64>())
.collect::<Result<Vec<_>, _>>()
.context("parsing daemon version")?;
if daemon_parts.len() != 3 {
return Err(anyhow!(
"parsing daemon version: got {} parts, want 3",
daemon_parts.len(),
));
}

// pre 1.0 releases flag breaking changes with their
// minor version rather than major version.
if client_parts[0] == 0 && daemon_parts[0] == 0 {
return Ok(client_parts[1].cmp(&daemon_parts[1]));
}

Ok(client_parts[0].cmp(&daemon_parts[0]))
}

/// pipe_bytes suffles bytes from std{in,out} to the unix
/// socket and back again. It is the main loop of
/// `shpool attach`.
Expand Down Expand Up @@ -317,4 +406,54 @@ mod test {
assert_eq!(c, round_tripped);
}
}

#[test]
fn version_ordering_noerr() {
use std::cmp::Ordering;

let cases = vec![
("1.0.0", "1.0.0", Ordering::Equal),
("1.0.0", "1.0.1", Ordering::Equal),
("1.0.0", "1.1.0", Ordering::Equal),
("1.0.0", "1.1.1", Ordering::Equal),
("1.0.0", "1.100.100", Ordering::Equal),
("1.0.0", "2.0.0", Ordering::Less),
("1.0.0", "2.8.0", Ordering::Less),
("1.199.0", "2.8.0", Ordering::Less),
("2.0.0", "1.0.0", Ordering::Greater),
("0.1.0", "0.1.0", Ordering::Equal),
("0.1.1", "0.1.0", Ordering::Equal),
("0.1.1", "0.1.99", Ordering::Equal),
("0.1.0", "0.2.0", Ordering::Less),
("0.1.99", "0.2.0", Ordering::Less),
("0.2.0", "0.1.0", Ordering::Greater),
];

for (lhs, rhs, ordering) in cases {
let actual_ordering =
Client::version_ord(lhs, rhs).expect("version strings to have an ordering");
assert_eq!(actual_ordering, ordering);
}
}

#[test]
fn version_ordering_err() {
let cases = vec![
("1.0.0", "1.0.0.0", "got 4 parts, want 3"),
("1.0.0.0", "1.0.0", "got 4 parts, want 3"),
("foobar", "1.0.0", "invalid digit found in string"),
("1.foobar", "1.0.0", "invalid digit found in string"),
];

for (lhs, rhs, err_substr) in cases {
if let Err(e) = Client::version_ord(lhs, rhs) {
eprintln!("ERR: {:?}", e);
eprintln!("EXPECTED SUBSTR: {}", err_substr);
let errstr = format!("{:?}", e);
assert!(errstr.contains(err_substr));
} else {
panic!("no error though we expected one");
}
}
}
}
2 changes: 1 addition & 1 deletion shpool-protocol/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "shpool-protocol"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
authors = ["Ethan Pailes <pailes@google.com>"]
repository = "https://github.com/shell-pool/shpool"
Expand Down
14 changes: 13 additions & 1 deletion shpool-protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,19 @@ use serde_derive::{Deserialize, Serialize};

pub const VERSION: &str = env!("CARGO_PKG_VERSION");

/// ConnectHeader is the blob of metadata that a client transmits when it
/// The header used to advertize daemon version.
///
/// This header gets written by the daemon to every stream as
/// soon as it is opened, which allows the client to compare
/// version strings for protocol negotiation (basically just
/// deciding if the user ought to be warned about mismatched
/// versions).
#[derive(Serialize, Deserialize, Debug)]
pub struct VersionHeader {
pub version: String,
}

/// The blob of metadata that a client transmits when it
/// first connects.
///
/// It uses an enum to allow different connection types
Expand Down
Loading

0 comments on commit e62a2e8

Please sign in to comment.