Skip to content

Commit

Permalink
feat: Add support for CONDA_OVERRIDE_CUDA (#818)
Browse files Browse the repository at this point in the history
  • Loading branch information
SobhanMP authored Aug 20, 2024
1 parent 8a397a0 commit 721a6c1
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 6 deletions.
2 changes: 1 addition & 1 deletion crates/rattler_lock/src/utils/serde/pep440_map_or_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<'de> DeserializeAs<'de, Vec<Requirement>> for Pep440MapOrVec {
} else {
Some(VersionOrUrl::VersionSpecifier(spec))
},
marker: Default::default(),
marker: Option::default(),
origin: None,
})
})
Expand Down
2 changes: 1 addition & 1 deletion crates/rattler_package_streaming/src/reqwest/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use zip::result::ZipError;
/// to find the compressed data length.
/// Since we stream the package over a non seekable HTTP connection, this condition will cause an error during
/// decompression. In this case, we fallback to reading the whole data to a buffer before attempting decompression.
/// Read more in https://github.com/conda/rattler/issues/794
/// Read more in <https://github.com/conda/rattler/issues/794>
const DATA_DESCRIPTOR_ERROR_MESSAGE: &str = "The file length is not available in the local header";

fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
Expand Down
141 changes: 137 additions & 4 deletions crates/rattler_virtual_packages/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,51 @@ pub mod osx;

use archspec::cpu::Microarchitecture;
use once_cell::sync::OnceCell;
use rattler_conda_types::{GenericVirtualPackage, PackageName, Platform, Version};
use rattler_conda_types::{
GenericVirtualPackage, PackageName, ParseVersionError, Platform, Version,
};
use std::env;
use std::hash::{Hash, Hasher};
use std::str::FromStr;
use std::sync::Arc;

use crate::osx::ParseOsxVersionError;
use libc::DetectLibCError;
use linux::ParseLinuxVersionError;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

/// Traits for overridable virtual packages
/// Use as `Cuda::from_default_env_var.unwrap_or(Cuda::current().into()).unwrap()`
pub trait EnvOverride: Sized {
/// Parse `env_var_value`
fn from_env_var_name_with_var(
env_var_name: &str,
env_var_value: &str,
) -> Result<Self, ParseVersionError>;

/// Read the environment variable and if it exists, try to parse it with [`EnvOverride::from_env_var_name_with_var`]
/// If the output is:
/// - `None`, then the environment variable did not exist,
/// - `Some(Err(None))`, then the environment variable exist but was set to zero, so the package should be disabled
/// - `Some(Ok(pkg))`, then the override was for the package.
fn from_env_var_name(env_var_name: &str) -> Option<Result<Self, Option<ParseVersionError>>> {
let var = env::var(env_var_name).ok()?;
if var.is_empty() {
Some(Err(None))
} else {
Some(Self::from_env_var_name_with_var(env_var_name, &var).map_err(Some))
}
}

/// Default name of the environment variable that overrides the virtual package.
const DEFAULT_ENV_NAME: &'static str;

/// Shortcut for `EnvOverride::from_env_var_name(EnvOverride::DEFAULT_ENV_NAME)`.
fn from_default_env_var() -> Option<Result<Self, Option<ParseVersionError>>> {
Self::from_env_var_name(Self::DEFAULT_ENV_NAME)
}
}

/// An enum that represents all virtual package types provided by this library.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub enum VirtualPackage {
Expand Down Expand Up @@ -95,8 +131,8 @@ impl VirtualPackage {
/// Returns virtual packages detected for the current system or an error if the versions could
/// not be properly detected.
pub fn current() -> Result<&'static [Self], DetectVirtualPackageError> {
static DETECED_VIRTUAL_PACKAGES: OnceCell<Vec<VirtualPackage>> = OnceCell::new();
DETECED_VIRTUAL_PACKAGES
static DETECTED_VIRTUAL_PACKAGES: OnceCell<Vec<VirtualPackage>> = OnceCell::new();
DETECTED_VIRTUAL_PACKAGES
.get_or_try_init(try_detect_virtual_packages)
.map(Vec::as_slice)
}
Expand Down Expand Up @@ -188,6 +224,12 @@ impl From<Linux> for VirtualPackage {
}
}

impl From<Version> for Linux {
fn from(version: Version) -> Self {
Linux { version }
}
}

/// `LibC` virtual package description
#[derive(Clone, Eq, PartialEq, Hash, Debug, Deserialize)]
pub struct LibC {
Expand Down Expand Up @@ -229,6 +271,20 @@ impl From<LibC> for VirtualPackage {
}
}

impl EnvOverride for LibC {
const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_GLIBC";

fn from_env_var_name_with_var(
_env_var_name: &str,
env_var_value: &str,
) -> Result<Self, ParseVersionError> {
Version::from_str(env_var_value).map(|version| Self {
family: "glibc".into(),
version,
})
}
}

/// Cuda virtual package description
#[derive(Clone, Eq, PartialEq, Hash, Debug, Deserialize)]
pub struct Cuda {
Expand All @@ -243,6 +299,23 @@ impl Cuda {
}
}

impl From<Version> for Cuda {
fn from(version: Version) -> Self {
Self { version }
}
}

impl EnvOverride for Cuda {
fn from_env_var_name_with_var(
_env_var_name: &str,
env_var_value: &str,
) -> Result<Self, ParseVersionError> {
Version::from_str(env_var_value).map(|version| Self { version })
}

const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_CUDA";
}

impl From<Cuda> for GenericVirtualPackage {
fn from(cuda: Cuda) -> Self {
GenericVirtualPackage {
Expand Down Expand Up @@ -359,7 +432,7 @@ impl From<Archspec> for GenericVirtualPackage {
GenericVirtualPackage {
name: PackageName::new_unchecked("__archspec"),
version: Version::major(1),
build_string: archspec.spec.name().to_string(),
build_string: archspec.spec.name().into(),
}
}
}
Expand Down Expand Up @@ -403,13 +476,73 @@ impl From<Osx> for VirtualPackage {
}
}

impl From<Version> for Osx {
fn from(version: Version) -> Self {
Self { version }
}
}

impl EnvOverride for Osx {
fn from_env_var_name_with_var(
_env_var_name: &str,
env_var_value: &str,
) -> Result<Self, ParseVersionError> {
Version::from_str(env_var_value).map(|version| Self { version })
}

const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_OSX";
}

#[cfg(test)]
mod test {
use std::env;
use std::str::FromStr;

use rattler_conda_types::Version;

use crate::Cuda;
use crate::EnvOverride;
use crate::LibC;
use crate::Osx;
use crate::VirtualPackage;

#[test]
fn doesnt_crash() {
let virtual_packages = VirtualPackage::current().unwrap();
println!("{virtual_packages:?}");
}
#[test]
fn parse_libc() {
let v = "1.23";
let res = LibC {
version: Version::from_str(v).unwrap(),
family: "glibc".into(),
};
env::set_var(LibC::DEFAULT_ENV_NAME, v);
assert_eq!(LibC::from_default_env_var(), Some(Ok(res)));
env::set_var(LibC::DEFAULT_ENV_NAME, "");
assert_eq!(LibC::from_default_env_var(), Some(Err(None)));
env::remove_var(LibC::DEFAULT_ENV_NAME);
assert_eq!(LibC::from_default_env_var(), None);
}

#[test]
fn parse_cuda() {
let v = "1.234";
let res = Cuda {
version: Version::from_str(v).unwrap(),
};
env::set_var(Cuda::DEFAULT_ENV_NAME, v);
assert_eq!(Cuda::from_default_env_var(), Some(Ok(res)));
}

#[test]
fn parse_osx() {
let v = "2.345";
let res = Osx {
version: Version::from_str(v).unwrap(),
};
env::set_var(Osx::DEFAULT_ENV_NAME, v);
assert_eq!(Osx::from_default_env_var(), Some(Ok(res)));
}
}

0 comments on commit 721a6c1

Please sign in to comment.