Skip to content

Commit

Permalink
Specialize localhost session
Browse files Browse the repository at this point in the history
  • Loading branch information
jaywonchung committed Nov 23, 2023
1 parent c0f2bdc commit ec2669f
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 106 deletions.
25 changes: 19 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "pegasus-ssh"
description = "Pegasus: A Multi-Node SSH Command Runner"
authors = ["Jae-Won Chung <jaewon.chung.cs@gmail.com>"]
version = "1.1.3"
version = "1.2.0"
edition = "2021"
repository = "https://github.com/jaywonchung/pegasus"
license = "MIT"
Expand All @@ -28,6 +28,8 @@ handlebars_misc_helpers = { version = "0.13.0", features = ["string", "json"] }
colored = "2"
colourado = "0.2.0"
memchr = "2.4.1"
thiserror = "1.0.50"
async-trait = "0.1.74"

[profile.release]
panic = "abort"
9 changes: 9 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use thiserror::Error;

#[derive(Error, Debug)]
pub enum PegasusError {
#[error("failed to connect SSH session or execute SSH command")]
SshError(#[from] openssh::Error),
#[error("failed to execute local command")]
LocalCommandError(#[from] std::io::Error),
}
34 changes: 33 additions & 1 deletion src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@
//! One SSH connection is created for one `Host`. Each connection will run commands in parallel with
//! other connections in its own tokio task.

use colourado::Color;
use itertools::sorted;
use std::collections::HashMap;
use std::fmt;
use std::fs::File;
use std::str::FromStr;

use colored::*;
use colourado::Color;
use openssh::{KnownHosts, Session as SSHSession};
use serde::Deserialize;
use void::Void;

use crate::error::PegasusError;
use crate::serde::string_or_mapping;
use crate::session::{LocalSession, RemoteSession, Session};

#[derive(Debug, Clone)]
pub struct Host {
Expand All @@ -32,6 +35,35 @@ impl Host {
}
}

/// Connects to the host over SSH.
pub async fn connect(
&self,
color: Color,
) -> Result<Box<dyn Session + Send + Sync>, PegasusError> {
let colorhost = self.prettify(color);
// Localhost does not need an SSH connection.
if self.is_localhost() {
eprintln!("{} Just spawning subprocess for localhost.", colorhost);
Ok(Box::new(LocalSession::new(colorhost)))
} else {
let session = match SSHSession::connect_mux(&self.hostname, KnownHosts::Add).await {
Ok(session) => session,
Err(e) => {
eprintln!("{} Failed to connect to host: {:?}", colorhost, e);
return Err(PegasusError::SshError(e));
}
};
eprintln!("{} Connected to host.", colorhost);
Ok(Box::new(RemoteSession::new(colorhost, session)))
}
}

/// Returns true if the host is localhost.
pub fn is_localhost(&self) -> bool {
let hostname = self.hostname.trim().to_lowercase();
hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1"
}

/// For pretty-printing the host name.
/// Surrounds with brackets and colors it with a random color.
pub fn prettify(&self, color: Color) -> ColoredString {
Expand Down
28 changes: 12 additions & 16 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ mod sync;
mod session;
// Provides utility for std::io::Writer
mod writer;
// Error handling.
mod error;

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
Expand All @@ -26,12 +28,12 @@ use tokio::sync::{broadcast, Barrier, Mutex};
use tokio::time;

use crate::config::{Config, Mode};
use crate::error::PegasusError;
use crate::host::get_hosts;
use crate::job::Cmd;
use crate::session::Session;
use crate::sync::LockedFile;

async fn run_broadcast(cli: &Config) -> Result<(), openssh::Error> {
async fn run_broadcast(cli: &Config) -> Result<(), PegasusError> {
let hosts = get_hosts(&cli.hosts_file);
let num_hosts = hosts.len();

Expand Down Expand Up @@ -68,11 +70,9 @@ async fn run_broadcast(cli: &Config) -> Result<(), openssh::Error> {
let end_barrier = Arc::clone(&end_barrier);
let errored = Arc::clone(&errored);
let print_period = cli.print_period;
// Open a new SSH session with the host.
let session = host.connect(color).await?;
tasks.push(tokio::spawn(async move {
// Open a new SSH session with the host.
let session = Session::connect(host, color)
.await
.expect("Failed to connect to host");
// Handlebars registry for filling in parameters.
let mut registry = Handlebars::new();
handlebars_misc_helpers::register(&mut registry);
Expand All @@ -82,14 +82,13 @@ async fn run_broadcast(cli: &Config) -> Result<(), openssh::Error> {
// session object to be dropped, and everyting gracefully
// terminated.
while let Ok(cmd) = command_rx.recv().await {
let cmd = cmd.fill_template(&mut registry, &session.host);
let cmd = cmd.fill_template(&mut registry, &host);
let result = session.run(cmd, print_period).await;
if result.is_err() || result.unwrap().code() != Some(0) {
errored.store(true, Ordering::Relaxed);
}
end_barrier.wait().await;
}
session.close().await;
}));
}

Expand Down Expand Up @@ -134,7 +133,7 @@ async fn run_broadcast(cli: &Config) -> Result<(), openssh::Error> {
Ok(())
}

async fn run_queue(cli: &Config) -> Result<(), openssh::Error> {
async fn run_queue(cli: &Config) -> Result<(), PegasusError> {
let hosts = get_hosts(&cli.hosts_file);
let num_hosts = hosts.len();

Expand Down Expand Up @@ -162,11 +161,9 @@ async fn run_queue(cli: &Config) -> Result<(), openssh::Error> {
command_txs.push(command_tx);
let notify_tx = notify_tx.clone();
let print_period = cli.print_period;
// Open a new SSH session with the host.
let session = host.connect(color).await?;
tasks.push(tokio::spawn(async move {
// Open a new SSH session with the host.
let session = Session::connect(host, color)
.await
.expect("Failed to connect to host");
// Handlebars registry for filling in parameters.
let mut registry = Handlebars::new();
handlebars_misc_helpers::register(&mut registry);
Expand All @@ -181,13 +178,12 @@ async fn run_queue(cli: &Config) -> Result<(), openssh::Error> {
// Receive and run the command.
match command_rx.recv_async().await {
Ok(cmd) => {
let cmd = cmd.fill_template(&mut registry, &session.host);
let cmd = cmd.fill_template(&mut registry, &host);
let _ = session.run(cmd, print_period).await;
}
Err(_) => break,
};
}
session.close().await;
}));
}

Expand Down Expand Up @@ -257,7 +253,7 @@ async fn run_lock(cli: &Config) {
}

#[tokio::main]
async fn main() -> Result<(), openssh::Error> {
async fn main() -> Result<(), PegasusError> {
let cli = Config::parse();

match cli.mode {
Expand Down
Loading

0 comments on commit ec2669f

Please sign in to comment.