Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rust-runtime to rustls 0.22 #3458

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aws/rust-runtime/aws-config/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"

# used for a usage example
hyper-rustls = { version = "0.24", features = ["webpki-tokio", "http2", "http1"] }
hyper-rustls = { version = "0.25", features = ["webpki-tokio", "http2", "http1"] }
aws-smithy-async = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-async", features = ["rt-tokio", "test-util"] }


Expand Down
2 changes: 1 addition & 1 deletion aws/rust-runtime/aws-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ http = "0.2.6"
# cargo does not support optional test dependencies, so to completely disable rustls
# we need to add the webpki-roots feature here.
# https://github.com/rust-lang/cargo/issues/1596
hyper-rustls = { version = "0.24", optional = true, features = ["rustls-native-certs", "http2", "webpki-roots"] }
hyper-rustls = { version = "0.25", optional = true, features = ["rustls-native-certs", "http2", "webpki-roots"] }

[dev-dependencies]
http = "0.2.4"
Expand Down
8 changes: 4 additions & 4 deletions rust-runtime/aws-smithy-http-server-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ bytes = "1.2"
futures = "0.3"
http = "0.2"
hyper = { version = "0.14.26", features = ["server", "http1", "http2", "tcp", "stream"] }
tls-listener = { version = "0.7.0", features = ["rustls", "hyper-h2"] }
rustls-pemfile = "1.0.1"
tokio-rustls = "0.24.0"
lambda_http = { version = "0.8.0" }
num_cpus = "1.13.1"
parking_lot = "0.12.1"
pin-project-lite = "0.2"
pyo3 = "0.18.2"
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime"] }
rustls-pemfile = "2"
signal-hook = { version = "0.3.14", features = ["extended-siginfo"] }
socket2 = { version = "0.5.2", features = ["all"] }
thiserror = "1.0.32"
tls-listener = { version = "0.9", features = ["rustls"] }
tokio = { version = "1.20.1", features = ["full"] }
tokio-rustls = "0.25"
tokio-stream = "0.1"
tower = { version = "0.4.13", features = ["util"] }
tracing = "0.1.36"
Expand All @@ -48,7 +48,7 @@ tower-test = "0.4"
tokio-test = "0.4"
pyo3-asyncio = { version = "0.18.0", features = ["testing", "attributes", "tokio-runtime", "unstable-streams"] }
rcgen = "0.10.0"
hyper-rustls = { version = "0.24", features = ["http2"] }
hyper-rustls = { version = "0.25", features = ["http2"] }

# PyO3 Asyncio tests cannot use Cargo's default testing harness because `asyncio`
# wants to control the main thread. So we need to use testing harness provided by `pyo3_asyncio`
Expand Down
46 changes: 7 additions & 39 deletions rust-runtime/aws-smithy-http-server-python/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use std::convert::Infallible;
use std::net::TcpListener as StdTcpListener;
use std::ops::Deref;
use std::process;
use std::sync::{mpsc, Arc};
use std::thread;

use aws_smithy_http_server::{
Expand All @@ -22,7 +21,6 @@ use pyo3::{prelude::*, types::IntoPyDict};
use signal_hook::{consts::*, iterator::Signals};
use socket2::Socket;
use tokio::{net::TcpListener, runtime};
use tokio_rustls::TlsAcceptor;
use tower::{util::BoxCloneService, ServiceBuilder};

use crate::{
Expand Down Expand Up @@ -257,11 +255,10 @@ event_loop.add_signal_handler(signal.SIGINT,
.build()
.expect("unable to start a new tokio runtime for this process");
rt.block_on(async move {
let addr = addr_incoming_from_socket(raw_socket);
let listener = tcp_listener_from_socket(raw_socket);

if let Some(config) = tls {
let (acceptor, acceptor_rx) = tls_config_reloader(config);
let listener = TlsListener::new(acceptor, addr, acceptor_rx);
let listener = TlsListener::new(config, listener);
let server =
hyper::Server::builder(listener).serve(IntoMakeService::new(service));

Expand All @@ -271,6 +268,8 @@ event_loop.add_signal_handler(signal.SIGINT,
tracing::error!(error = ?err, "server error");
}
} else {
let addr = AddrIncoming::from_listener(listener)
.expect("unable to create `AddrIncoming` from `TcpListener`");
let server = hyper::Server::builder(addr).serve(IntoMakeService::new(service));

tracing::trace!("started hyper server from shared socket");
Expand Down Expand Up @@ -498,43 +497,12 @@ event_loop.add_signal_handler(signal.SIGINT,
}
}

fn addr_incoming_from_socket(socket: Socket) -> AddrIncoming {
fn tcp_listener_from_socket(socket: Socket) -> TcpListener {
let std_listener: StdTcpListener = socket.into();
// StdTcpListener::from_std doesn't set O_NONBLOCK
std_listener
.set_nonblocking(true)
.expect("unable to set `O_NONBLOCK=true` on `std::net::TcpListener`");
let listener = TcpListener::from_std(std_listener)
.expect("unable to create `tokio::net::TcpListener` from `std::net::TcpListener`");
AddrIncoming::from_listener(listener)
.expect("unable to create `AddrIncoming` from `TcpListener`")
}

// Builds `TlsAcceptor` from given `config` and also creates a background task
// to reload certificates and returns a channel to receive new `TlsAcceptor`s.
fn tls_config_reloader(config: PyTlsConfig) -> (TlsAcceptor, mpsc::Receiver<TlsAcceptor>) {
let reload_dur = config.reload_duration();
let (tx, rx) = mpsc::channel();
let acceptor = TlsAcceptor::from(Arc::new(config.build().expect("invalid tls config")));

tokio::spawn(async move {
tracing::trace!(dur = ?reload_dur, "starting timer to reload tls config");
loop {
tokio::time::sleep(reload_dur).await;
tracing::trace!("reloading tls config");
match config.build() {
Ok(config) => {
let new_config = TlsAcceptor::from(Arc::new(config));
// Note on expect: `tx.send` can only fail if the receiver is dropped,
// it probably a bug if that happens
tx.send(new_config).expect("could not send new tls config")
}
Err(err) => {
tracing::error!(error = ?err, "could not reload tls config because it is invalid");
}
}
}
});

(acceptor, rx)
TcpListener::from_std(std_listener)
.expect("unable to create `tokio::net::TcpListener` from `std::net::TcpListener`")
}
44 changes: 26 additions & 18 deletions rust-runtime/aws-smithy-http-server-python/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use std::time::Duration;

use pyo3::{pyclass, pymethods};
use thiserror::Error;
use tokio_rustls::rustls::{Certificate, Error as RustTlsError, PrivateKey, ServerConfig};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio_rustls::rustls::{Error as RustTlsError, ServerConfig};

pub mod listener;

Expand Down Expand Up @@ -53,7 +54,6 @@ impl PyTlsConfig {
let cert_chain = self.cert_chain()?;
let key_der = self.key_der()?;
let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert_chain, key_der)?;
config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
Expand All @@ -66,18 +66,16 @@ impl PyTlsConfig {
}

/// Reads certificates from `cert_path`.
fn cert_chain(&self) -> Result<Vec<Certificate>, PyTlsConfigError> {
fn cert_chain(&self) -> Result<Vec<CertificateDer<'static>>, PyTlsConfigError> {
let file = File::open(&self.cert_path).map_err(PyTlsConfigError::CertParse)?;
let mut cert_rdr = BufReader::new(file);
Ok(rustls_pemfile::certs(&mut cert_rdr)
.map_err(PyTlsConfigError::CertParse)?
.into_iter()
.map(Certificate)
.collect())
rustls_pemfile::certs(&mut cert_rdr)
.collect::<Result<Vec<_>, _>>()
.map_err(PyTlsConfigError::CertParse)
}

/// Parses RSA or PKCS private key from `key_path`.
fn key_der(&self) -> Result<PrivateKey, PyTlsConfigError> {
fn key_der(&self) -> Result<PrivateKeyDer<'static>, PyTlsConfigError> {
let mut key_vec = Vec::new();
File::open(&self.key_path)
.and_then(|mut f| f.read_to_end(&mut key_vec))
Expand All @@ -86,16 +84,26 @@ impl PyTlsConfig {
return Err(PyTlsConfigError::EmptyKey);
}

let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut key_vec.as_slice())
.map_err(PyTlsConfigError::Pkcs8Parse)?;
if !pkcs8.is_empty() {
return Ok(PrivateKey(pkcs8.remove(0)));
let mut key_slice = key_vec.as_slice();
let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut key_slice);
match pkcs8.next() {
Some(Ok(key)) if !key.secret_pkcs8_der().is_empty() => {
return Ok(PrivateKeyDer::from(key).clone_key())
}
Some(Ok(_)) => return Err(PyTlsConfigError::EmptyKey),
Some(Err(e)) => return Err(PyTlsConfigError::Pkcs8Parse(e)),
None => {}
}

let mut rsa = rustls_pemfile::rsa_private_keys(&mut key_vec.as_slice())
.map_err(PyTlsConfigError::RsaParse)?;
if !rsa.is_empty() {
return Ok(PrivateKey(rsa.remove(0)));
let mut key_slice = key_vec.as_slice();
let mut rsa = rustls_pemfile::rsa_private_keys(&mut key_slice);
match rsa.next() {
Some(Ok(key)) if !key.secret_pkcs1_der().is_empty() => {
return Ok(PrivateKeyDer::from(key))
}
Some(Ok(_)) => return Err(PyTlsConfigError::EmptyKey),
Some(Err(e)) => return Err(PyTlsConfigError::Pkcs8Parse(e)),
None => {}
}

Err(PyTlsConfigError::EmptyKey)
Expand Down Expand Up @@ -129,7 +137,7 @@ pub enum PyTlsConfigError {
Pkcs8Parse(io::Error),
#[error("could not parse rsa keys")]
RsaParse(io::Error),
#[error("rusttls protocol error")]
#[error("rustls protocol error")]
RustTlsError(#[from] RustTlsError),
}

Expand Down
Loading
Loading