diff --git a/crates/rattler-bin/src/commands/create.rs b/crates/rattler-bin/src/commands/create.rs
index 51624775f..5b40a7c4f 100644
--- a/crates/rattler-bin/src/commands/create.rs
+++ b/crates/rattler-bin/src/commands/create.rs
@@ -11,7 +11,9 @@ use rattler_conda_types::{
     Channel, ChannelConfig, GenericVirtualPackage, MatchSpec, PackageRecord, Platform,
     PrefixRecord, RepoDataRecord, Version,
 };
-use rattler_networking::{AuthenticatedClient, AuthenticationStorage};
+use rattler_networking::{
+    retry_policies::default_retry_policy, AuthenticatedClient, AuthenticationStorage,
+};
 use rattler_repodata_gateway::fetch::{
     CacheResult, DownloadProgress, FetchRepoDataError, FetchRepoDataOptions,
 };
@@ -397,10 +399,11 @@ async fn execute_operation(
     async {
         // Make sure the package is available in the package cache.
         let result = package_cache
-            .get_or_fetch_from_url(
+            .get_or_fetch_from_url_with_retry(
                 &install_record.package_record,
                 install_record.url.clone(),
                 download_client.clone(),
+                default_retry_policy(),
             )
             .map_ok(|cache_dir| Some((install_record.clone(), cache_dir)))
             .map_err(anyhow::Error::from)
diff --git a/crates/rattler/Cargo.toml b/crates/rattler/Cargo.toml
index 048314721..50c9c346c 100644
--- a/crates/rattler/Cargo.toml
+++ b/crates/rattler/Cargo.toml
@@ -56,3 +56,8 @@ rand = "0.8.5"
 rstest = "0.18.1"
 tracing-test = { version = "0.2.4" }
 insta = { version = "1.30.0", features = ["yaml"] }
+
+tokio = { version = "1.29.1", features = ["macros", "rt-multi-thread"] }
+axum = "0.6.18"
+tower-http = { version = "0.4.1", features = ["fs"] }
+tower = { version = "0.4.13", default-features = false, features = ["util"] }
diff --git a/crates/rattler/src/package_cache.rs b/crates/rattler/src/package_cache.rs
index b611900b0..bad1f9aff 100644
--- a/crates/rattler/src/package_cache.rs
+++ b/crates/rattler/src/package_cache.rs
@@ -1,11 +1,16 @@
 //! This module provides functionality to cache extracted Conda packages. See [`PackageCache`].
 
 use crate::validation::validate_package_directory;
+use chrono::Utc;
 use fxhash::FxHashMap;
 use itertools::Itertools;
-use rattler_conda_types::package::ArchiveIdentifier;
-use rattler_conda_types::PackageRecord;
-use rattler_networking::AuthenticatedClient;
+use rattler_conda_types::{package::ArchiveIdentifier, PackageRecord};
+use rattler_networking::{
+    retry_policies::{DoNotRetryPolicy, RetryDecision, RetryPolicy},
+    AuthenticatedClient,
+};
+use rattler_package_streaming::ExtractError;
+use reqwest::StatusCode;
 use std::error::Error;
 use std::{
     fmt::{Display, Formatter},
@@ -182,12 +187,71 @@ impl PackageCache {
         pkg: impl Into<CacheKey>,
         url: Url,
         client: AuthenticatedClient,
+    ) -> Result<PathBuf, PackageCacheError> {
+        self.get_or_fetch_from_url_with_retry(pkg, url, client, DoNotRetryPolicy)
+            .await
+    }
+
+    /// Returns the directory that contains the specified package.
+    ///
+    /// This is a convenience wrapper around `get_or_fetch` which fetches the package from the given
+    /// URL if it is not cached, retrying failed downloads according to `retry_policy`.
+    pub async fn get_or_fetch_from_url_with_retry(
+        &self,
+        pkg: impl Into<CacheKey>,
+        url: Url,
+        client: AuthenticatedClient,
+        retry_policy: impl RetryPolicy + Send + 'static,
+    ) -> Result<PathBuf, PackageCacheError> {
         self.get_or_fetch(pkg, move |destination| async move {
-            tracing::debug!("downloading {} to {}", &url, destination.display());
-            rattler_package_streaming::reqwest::tokio::extract(client, url, &destination)
-                .await
-                .map(|_| ())
+            let mut current_try = 0;
+            loop {
+                current_try += 1;
+                tracing::debug!("downloading {} to {}", &url, destination.display());
+                let result = rattler_package_streaming::reqwest::tokio::extract(
+                    client.clone(),
+                    url.clone(),
+                    &destination,
+                )
+                .await;
+
+                // Success: nothing left to do. Otherwise keep the error for inspection.
+                let Err(err) = result else { return Ok(()); };
+
+                // Only retry failures that are plausibly transient: local I/O problems and
+                if !matches!(
+                    &err,
+                    ExtractError::IoError(_) | ExtractError::CouldNotCreateDestination(_)
+                ) && !matches!(&err, ExtractError::ReqwestError(err) if
+                    err.is_timeout() ||
+                    err.is_connect() ||
+                    err
+                        .status()
+                        .map(|status| status.is_server_error() || status == StatusCode::TOO_MANY_REQUESTS || status == StatusCode::REQUEST_TIMEOUT)
+                        .unwrap_or(false)
+                ) {
+                    return Err(err);
+                }
+
+                // Ask the retry policy whether (and when) to try again.
+                let execute_after = match retry_policy.should_retry(current_try) {
+                    RetryDecision::Retry { execute_after } => execute_after,
+                    RetryDecision::DoNotRetry => return Err(err),
+                };
+                let duration = (execute_after - Utc::now()).to_std().expect("the retry duration is out of range");
+
+                // Wait for the policy-determined backoff before the next attempt; this gives
+                // the remote service time to recover and increases the chance of success.
+                tracing::warn!(
+                    "failed to download and extract {} to {}: {}. Retry #{}, Sleeping {:?} until the next attempt...",
+                    &url,
+                    destination.display(),
+                    err,
+                    current_try,
+                    duration
+                );
+                tokio::time::sleep(duration).await;
+            }
         })
         .await
     }
@@ -240,9 +304,26 @@ where
 mod test {
     use super::PackageCache;
     use crate::{get_test_data_dir, validation::validate_package_directory};
+    use assert_matches::assert_matches;
+    use axum::{
+        extract::State,
+        http::{Request, StatusCode},
+        middleware,
+        middleware::Next,
+        response::Response,
+        routing::get_service,
+        Router,
+    };
     use rattler_conda_types::package::{ArchiveIdentifier, PackageFile, PathsJson};
-    use std::{fs::File, path::Path};
+    use rattler_networking::{
+        retry_policies::{DoNotRetryPolicy, ExponentialBackoffBuilder},
+        AuthenticatedClient,
+    };
+    use std::{fs::File, net::SocketAddr, path::Path, sync::Arc};
     use tempfile::tempdir;
+    use tokio::sync::Mutex;
+    use tower_http::services::ServeDir;
+    use url::Url;
 
     #[tokio::test]
     pub async fn test_package_cache() {
@@ -284,4 +365,95 @@ mod test {
         // archive.
         assert_eq!(current_paths, paths);
     }
+
+    /// A helper middleware function that fails the first two requests.
+    async fn fail_the_first_two_requests<B>(
+        State(count): State<Arc<Mutex<usize>>>,
+        req: Request<B>,
+        next: Next<B>,
+    ) -> Result<Response, StatusCode> {
+        let count = {
+            let mut count = count.lock().await;
+            *count += 1;
+            *count
+        };
+
+        println!("Running middleware for request #{count} for {}", req.uri());
+        if count <= 2 {
+            println!("Discarding request!");
+            return Err(StatusCode::INTERNAL_SERVER_ERROR);
+        }
+
+        // Pass the request on to the file-serving service.
+        Ok(next.run(req).await)
+    }
+
+    #[tokio::test]
+    pub async fn test_flaky_package_cache() {
+        let static_dir = get_test_data_dir();
+
+        // Construct a service that serves raw files from the test directory
+        let service = get_service(ServeDir::new(static_dir));
+
+        // Construct a router that returns data from the static dir but fails the first try.
+        let request_count = Arc::new(Mutex::new(0));
+        let router =
+            Router::new()
+                .route_service("/*key", service)
+                .layer(middleware::from_fn_with_state(
+                    request_count.clone(),
+                    fail_the_first_two_requests,
+                ));
+
+        // Construct the server that will listen on localhost but with a *random port*. The random
+        // port is very important because it enables creating multiple instances at the same time.
+        // We need this to be able to run tests in parallel.
+        let addr = SocketAddr::new([127, 0, 0, 1].into(), 0);
+        let server = axum::Server::bind(&addr).serve(router.into_make_service());
+
+        // Get the address of the server so we can bind to it at a later stage.
+        let addr = server.local_addr();
+
+        // Spawn the server.
+        tokio::spawn(server);
+
+        let packages_dir = tempdir().unwrap();
+        let cache = PackageCache::new(packages_dir.path());
+
+        let archive_name = "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2";
+        let server_url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap();
+
+        // Do the first request without a retry policy.
+        let result = cache
+            .get_or_fetch_from_url_with_retry(
+                ArchiveIdentifier::try_from_filename(archive_name).unwrap(),
+                server_url.join(archive_name).unwrap(),
+                AuthenticatedClient::default(),
+                DoNotRetryPolicy,
+            )
+            .await;
+
+        // First request without retry policy should fail
+        assert_matches!(result, Err(_));
+        {
+            let request_count_lock = request_count.lock().await;
+            assert_eq!(*request_count_lock, 1, "Expected there to be 1 request");
+        }
+
+        // The second call retries past the two injected failures and succeeds on request #3.
+        let result = cache
+            .get_or_fetch_from_url_with_retry(
+                ArchiveIdentifier::try_from_filename(archive_name).unwrap(),
+                server_url.join(archive_name).unwrap(),
+                AuthenticatedClient::default(),
+                ExponentialBackoffBuilder::default().build_with_max_retries(3),
+            )
+            .await;
+
+        assert!(result.is_ok());
+        {
+            let request_count_lock = request_count.lock().await;
+            assert_eq!(*request_count_lock, 3, "Expected there to be 3 requests");
+        }
+    }
 }
diff --git a/crates/rattler_networking/Cargo.toml b/crates/rattler_networking/Cargo.toml
index 0c2e7f111..9d43358a3 100644
--- a/crates/rattler_networking/Cargo.toml
+++ b/crates/rattler_networking/Cargo.toml
@@ -21,6 +21,7 @@ keyring = "2.0.4"
 lazy_static = "1.4.0"
 libc = "0.2.147"
 reqwest = { version = "0.11.18", features = ["blocking"], default-features = false}
+retry-policies = { version = "0.2.0", default-features = false }
 serde = "1.0.171"
 serde_json = "1.0.102"
 thiserror = "1.0.43"
diff --git a/crates/rattler_networking/src/lib.rs b/crates/rattler_networking/src/lib.rs
index 009f0e535..6da7abfb3 100644
--- a/crates/rattler_networking/src/lib.rs
+++ b/crates/rattler_networking/src/lib.rs
@@ -8,6 +8,7 @@ pub use authentication_storage::{authentication::Authentication, storage::AuthenticationStorage};
 use reqwest::{Client, IntoUrl, Method, Url};
 
 pub mod authentication_storage;
+pub mod retry_policies;
 
 /// A client that can be used to make authenticated requests, based on the [`reqwest::Client`]
 #[derive(Clone)]
diff --git a/crates/rattler_networking/src/retry_policies.rs b/crates/rattler_networking/src/retry_policies.rs
new file mode 100644
index 000000000..bd3fca2e8
--- /dev/null
+++ b/crates/rattler_networking/src/retry_policies.rs
@@ -0,0 +1,23 @@
+//! Reexports the trait [`RetryPolicy`] from the `retry_policies` crate as well as all
+//! implementations.
+//!
+//! This module also provides the [`DoNotRetryPolicy`] which is useful if you do not want to retry
+//! anything.
+
+pub use retry_policies::{policies::*, Jitter, RetryDecision, RetryPolicy};
+
+/// A simple [`RetryPolicy`] that just never retries.
+pub struct DoNotRetryPolicy;
+impl RetryPolicy for DoNotRetryPolicy {
+    fn should_retry(&self, _: u32) -> RetryDecision {
+        RetryDecision::DoNotRetry
+    }
+}
+
+/// Returns the default retry policy that can be used when no specific policy is required.
+///
+/// This is useful if you just do not care about a retry policy and you just want something
+/// sensible. Note that the behavior of what is "sensible" might change over time.
+pub fn default_retry_policy() -> ExponentialBackoff {
+    ExponentialBackoff::builder().build_with_max_retries(3)
+}