diff --git a/crates/rattler/src/package_cache.rs b/crates/rattler/src/package_cache.rs index c1bd750b3..cb2deab8c 100644 --- a/crates/rattler/src/package_cache.rs +++ b/crates/rattler/src/package_cache.rs @@ -38,6 +38,14 @@ pub struct CacheKey { name: String, version: String, build_string: String, + sha256: Option, +} + +impl CacheKey { + /// Return the sha256 hash of the package if it is known. + pub fn sha256(&self) -> Option<&str> { + self.sha256.as_deref() + } } impl From for CacheKey { @@ -46,6 +54,7 @@ impl From for CacheKey { name: pkg.name, version: pkg.version, build_string: pkg.build_string, + sha256: None, } } } @@ -56,6 +65,7 @@ impl From<&PackageRecord> for CacheKey { name: record.name.as_normalized().to_string(), version: record.version.to_string(), build_string: record.build.clone(), + sha256: record.sha256.map(|s| format!("{s:x}")).clone(), } } } diff --git a/crates/rattler_networking/Cargo.toml b/crates/rattler_networking/Cargo.toml index 73ae38958..ef64e7b73 100644 --- a/crates/rattler_networking/Cargo.toml +++ b/crates/rattler_networking/Cargo.toml @@ -11,6 +11,7 @@ license.workspace = true readme.workspace = true [features] +default = ["rustls-tls"] native-tls = ['reqwest/native-tls'] rustls-tls = ['reqwest/rustls-tls'] @@ -20,6 +21,7 @@ async-trait = { workspace = true } base64 = { workspace = true } dirs = { workspace = true } fslock = { workspace = true } +http = "0.2" itertools = { workspace = true } keyring = { workspace = true } lazy_static = { workspace = true } diff --git a/crates/rattler_networking/src/lib.rs b/crates/rattler_networking/src/lib.rs index ab84d2a73..94ea3b437 100644 --- a/crates/rattler_networking/src/lib.rs +++ b/crates/rattler_networking/src/lib.rs @@ -2,11 +2,12 @@ //! Networking utilities for Rattler, specifically authenticating requests pub use authentication_middleware::AuthenticationMiddleware; - pub use authentication_storage::{authentication::Authentication, storage::AuthenticationStorage}; +pub use mirror_middleware::{MirrorMiddleware, OciMiddleware}; pub mod authentication_middleware; pub mod authentication_storage; +pub mod mirror_middleware; pub mod retry_policies; mod redaction; diff --git a/crates/rattler_networking/src/mirror_middleware.rs b/crates/rattler_networking/src/mirror_middleware.rs new file mode 100644 index 000000000..d4639a753 --- /dev/null +++ b/crates/rattler_networking/src/mirror_middleware.rs @@ -0,0 +1,457 @@ +//! Middleware to handle mirrors +use std::{ + collections::HashMap, + sync::{ + atomic::{self, AtomicUsize}, + Arc, Mutex, + }, +}; + +use http::StatusCode; +use reqwest::{ + header::{ACCEPT, AUTHORIZATION}, + Request, Response, ResponseBuilderExt, +}; +use reqwest_middleware::{Middleware, Next, Result}; +use serde::Deserialize; +use task_local_extensions::Extensions; +use url::Url; + +#[allow(dead_code)] +/// Settings for the specific mirror (e.g. no zstd or bz2 support) +struct MirrorSettings { + no_zstd: bool, + no_bz2: bool, + no_gz: bool, + max_failures: Option, +} + +#[allow(dead_code)] +struct MirrorState { + url: Url, + + failures: AtomicUsize, + + settings: MirrorSettings, +} + +impl MirrorState { + pub fn add_failure(&self) { + self.failures.fetch_add(1, atomic::Ordering::Relaxed); + } +} + +/// Middleware to handle mirrors +pub struct MirrorMiddleware { + mirror_map: HashMap>, +} + +impl MirrorMiddleware { + /// Create a new `MirrorMiddleware` from a map of mirrors + pub fn from_map(map: HashMap>) -> Self { + let mirror_map = map + .into_iter() + .map(|(k, v)| { + let v = v + .into_iter() + .map(|url| { + let url = if url.ends_with('/') { + url + } else { + format!("{url}/") + }; + MirrorState { + url: Url::parse(&url).unwrap(), + failures: AtomicUsize::new(0), + settings: MirrorSettings { + no_zstd: false, + no_bz2: false, + no_gz: false, + max_failures: Some(3), + }, + } + }) + .collect(); + (k, v) + }) + .collect(); + + Self { mirror_map } + } +} + +fn select_mirror(mirrors: &[MirrorState]) -> &MirrorState { + let mut min_failures = usize::MAX; + let mut min_failures_index = 0; + + for (i, mirror) in mirrors.iter().enumerate() { + let failures = mirror.failures.load(atomic::Ordering::Relaxed); + if failures < min_failures { + min_failures = failures; + min_failures_index = i; + } + } + + &mirrors[min_failures_index] +} + +#[async_trait::async_trait] +impl Middleware for MirrorMiddleware { + async fn handle( + &self, + mut req: Request, + extensions: &mut Extensions, + next: Next<'_>, + ) -> Result { + let url_str = req.url().to_string(); + + for (key, mirrors) in self.mirror_map.iter() { + if let Some(url_rest) = url_str.strip_prefix(key) { + let url_rest = url_rest.trim_start_matches('/'); + // replace the key with the mirror + let selected_mirror = select_mirror(mirrors); + let selected_url = selected_mirror.url.join(url_rest).unwrap(); + *req.url_mut() = selected_url; + let res = next.run(req, extensions).await; + + // record a failure if the request failed so we can avoid the mirror in the future + match res.as_ref() { + Ok(res) if res.status().is_server_error() => selected_mirror.add_failure(), + Err(_) => selected_mirror.add_failure(), + _ => {} + } + + return res; + } + } + + // if we don't have a mirror, we don't need to do anything + next.run(req, extensions).await + } +} + +/// Middleware to handle `oci://` URLs +#[derive(Default, Debug, Clone)] +pub struct OciMiddleware { + token_cache: Arc>>, +} + +#[allow(dead_code)] +enum OciAction { + Pull, + Push, + PushPull, +} + +impl ToString for OciAction { + fn to_string(&self) -> String { + match self { + OciAction::Pull => "pull".to_string(), + OciAction::Push => "push".to_string(), + OciAction::PushPull => "push,pull".to_string(), + } + } +} + +#[derive(Clone, Debug, Deserialize)] +struct OCIToken { + token: String, +} + +pub(crate) fn create_404_response(url: &Url, body: &str) -> Response { + Response::from( + http::response::Builder::new() + .status(StatusCode::NOT_FOUND) + .url(url.clone()) + .body(body.to_string()) + .unwrap(), + ) +} + +// [oci://ghcr.io/channel-mirrors/conda-forge]/[osx-arm64/xtensor] +async fn get_token(url: &Url, action: OciAction) -> Result { + let token_url: String = format!( + "https://{}/token?scope=repository:{}:{}", + url.host_str().unwrap(), + &url.path()[1..], + action.to_string() + ); + + tracing::info!("Requesting token from {}", token_url); + + let token = reqwest::get(token_url) + .await + .map_err(reqwest_middleware::Error::Reqwest)? + .json::() + .await? + .token; + + Ok(token) +} + +fn oci_url_with_hash(url: &Url, hash: &str) -> Url { + format!( + "https://{}/v2{}/blobs/sha256:{}", + url.host_str().unwrap(), + url.path(), + hash + ) + .parse() + .unwrap() +} + +#[derive(Debug)] +struct OciTagMediaType { + url: Url, + tag: String, + media_type: String, +} + +#[allow(dead_code)] +fn reverse_version_build_tag(tag: &str) -> String { + tag.replace("__p__", "+") + .replace("__e__", "!") + .replace("__eq__", "=") +} + +fn version_build_tag(tag: &str) -> String { + tag.replace('+', "__p__") + .replace('!', "__e__") + .replace('=', "__eq__") +} + +/// We reimplement some logic from rattler here because we don't want to introduce cyclic dependencies +fn package_to_tag(url: &Url) -> OciTagMediaType { + // get filename (last segment of path) + let filename = url.path_segments().unwrap().last().unwrap(); + + let mut res = OciTagMediaType { + url: url.clone(), + tag: "latest".to_string(), + media_type: "".to_string(), + }; + + let mut computed_filename = filename.to_string(); + + if let Some(archive_name) = filename.strip_suffix(".conda") { + let parts = archive_name.rsplitn(3, '-').collect::>(); + computed_filename = parts[2].to_string(); + res.tag = version_build_tag(&format!("{}-{}", parts[1], parts[0])); + res.media_type = "application/vnd.conda.package.v2".to_string(); + } else if let Some(archive_name) = filename.strip_suffix(".tar.bz2") { + let parts = archive_name.rsplitn(3, '-').collect::>(); + computed_filename = parts[2].to_string(); + res.tag = version_build_tag(&format!("{}-{}", parts[1], parts[0])); + res.media_type = "application/vnd.conda.package.v1".to_string(); + } else if filename.starts_with("repodata.json") { + computed_filename = "repodata.json".to_string(); + if filename == "repodata.json" { + res.media_type = "application/vnd.conda.repodata.v1+json".to_string(); + } else if filename.ends_with(".gz") { + res.media_type = "application/vnd.conda.repodata.v1+json+gzip".to_string(); + } else if filename.ends_with(".bz2") { + res.media_type = "application/vnd.conda.repodata.v1+json+bz2".to_string(); + } else if filename.ends_with(".zst") { + res.media_type = "application/vnd.conda.repodata.v1+json+zst".to_string(); + } + } + + if computed_filename.starts_with('_') { + computed_filename = format!("zzz{computed_filename}"); + } + + res.url = url.join(&computed_filename).unwrap(); + res +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct Layer { + digest: String, + #[serde(rename = "mediaType")] + media_type: String, + size: u64, + annotations: Option>, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct Manifest { + schema_version: u64, + media_type: String, + layers: Vec, + config: Layer, + annotations: Option>, +} + +#[async_trait::async_trait] +impl Middleware for OciMiddleware { + async fn handle( + &self, + req: Request, + extensions: &mut Extensions, + next: Next<'_>, + ) -> Result { + // if the URL is not an OCI URL, we don't need to do anything + if req.url().scheme() != "oci" { + return next.run(req, extensions).await; + } + + let oci_info = package_to_tag(req.url()); + let url = &oci_info.url; + let token = self.token_cache.lock().unwrap().get(url).cloned(); + + let token = if let Some(token) = token { + token + } else { + let token = get_token(url, OciAction::Pull).await?; + self.token_cache + .lock() + .unwrap() + .insert(url.clone(), token.clone()); + token + }; + + let mut req = req; + req.headers_mut() + .insert(AUTHORIZATION, format!("Bearer {token}").parse().unwrap()); + + // if we know the hash, we can pull the artifact directly + // if we don't, we need to pull the manifest and then pull the artifact + if let Some(expected_sha_hash) = req + .headers() + .get("X-ExpectedSha256") + .map(|s| s.to_str().unwrap().to_string()) + { + *req.url_mut() = oci_url_with_hash(url, &expected_sha_hash); + next.run(req, extensions).await + } else { + // get the tag from the URL + // retrieve the manifest + let manifest_url = format!( + "https://{}/v2{}/manifests/{}", + url.host_str().unwrap(), + url.path(), + &oci_info.tag + ); + + let manifest = reqwest::Client::new() + .get(&manifest_url) + .header(AUTHORIZATION, format!("Bearer {token}")) + .header(ACCEPT, "application/vnd.oci.image.manifest.v1+json") + .send() + .await + .map_err(reqwest_middleware::Error::Reqwest)?; + + let manifest: Manifest = manifest.json().await?; + + let layer = if let Some(layer) = manifest + .layers + .iter() + .find(|l| l.media_type == oci_info.media_type) + { + layer + } else { + return Ok(create_404_response( + url, + "No layer available for media type", + )); + }; + + let layer_url = format!( + "https://{}/v2{}/blobs/{}", + url.host_str().unwrap(), + url.path(), + layer.digest + ); + *req.url_mut() = layer_url.parse().unwrap(); + next.run(req, extensions).await + } + } +} + +#[cfg(test)] +mod test { + use std::io::Write; + + use super::*; + + // #[tokio::test] + // async fn test_mirror_middleware() { + // let mut mirror_map = HashMap::new(); + // mirror_map.insert( + // "conda.anaconda.org".to_string(), + // vec![ + // "https://conda.anaconda.org/conda-forge".to_string(), + // "https://conda.anaconda.org/conda-forge".to_string(), + // ], + // ); + + // let middleware = MirrorMiddleware::from_map(mirror_map); + + // let client = reqwest::Client::new(); + // let mut extensions = Extensions::new(); + + // let response = middleware + // .handle( + // client.get("https://conda.anaconda.org/conda-forge/win-64/python-3.11.0-hcf16a7b_0_cpython.tar.bz2"), + // &mut extensions, + // |req, _| async { Ok(req.send().await.unwrap()) }, + // ) + // .await + // .unwrap(); + + // assert_eq!(response.status(), 200); + // } + + // test pulling an image from OCI registry + #[tokio::test] + async fn test_oci_middleware() { + let middleware = OciMiddleware::default(); + + let client = reqwest::Client::new(); + let client_with_middleware = reqwest_middleware::ClientBuilder::new(client) + .with(middleware) + .build(); + + let response = client_with_middleware + .get("oci://ghcr.io/channel-mirrors/conda-forge/osx-arm64/xtensor-0.25.0-h2ffa867_0.conda") + .header( + "X-ExpectedSha256", + "8485a64911c7011c0270b8266ab2bffa1da41c59ac4f0a48000c31d4f4a966dd", + ) + .send() + .await + .unwrap(); + + // write out to tempfile + let mut file = std::fs::File::create("./test.tar.bz2").unwrap(); + assert_eq!(response.status(), 200); + + file.write_all(&response.bytes().await.unwrap()).unwrap(); + } + + // test pulling an image from OCI registry + #[tokio::test] + async fn test_oci_middleware_repodata() { + let middleware = OciMiddleware::default(); + + let client = reqwest::Client::new(); + let client_with_middleware = reqwest_middleware::ClientBuilder::new(client) + .with(middleware) + .build(); + + let response = client_with_middleware + .get("oci://ghcr.io/channel-mirrors/conda-forge/osx-arm64/repodata.json") + .send() + .await + .unwrap(); + + // write out to tempfile + let mut file = std::fs::File::create("./test.json").unwrap(); + assert_eq!(response.status(), 200); + + file.write_all(&response.bytes().await.unwrap()).unwrap(); + } +} diff --git a/crates/rattler_package_streaming/src/reqwest/tokio.rs b/crates/rattler_package_streaming/src/reqwest/tokio.rs index 28159a570..8669fd8e8 100644 --- a/crates/rattler_package_streaming/src/reqwest/tokio.rs +++ b/crates/rattler_package_streaming/src/reqwest/tokio.rs @@ -22,9 +22,10 @@ async fn get_reader( client: reqwest_middleware::ClientWithMiddleware, ) -> Result { if url.scheme() == "file" { - let file = tokio::fs::File::open(url.to_file_path().expect("...")) - .await - .map_err(ExtractError::IoError)?; + let file = + tokio::fs::File::open(url.to_file_path().expect("Could not convert to file path")) + .await + .map_err(ExtractError::IoError)?; Ok(Either::Left(BufReader::new(file))) } else {