
Commit 990bfb3
Move the flags to a more appropriate place.
LaurentMazare committed Jul 25, 2023
1 parent 6ae1cc3
Showing 3 changed files with 2 additions and 12 deletions.
7 changes: 0 additions & 7 deletions candle-examples/build.rs
@@ -227,13 +227,6 @@ fn compute_cap() -> Result<usize> {
 
     println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
 
-    // TODO: Having to specify these manually on all binary packages would be annoying. We should
-    // check e.g. in the cc codebase how to ensure that these flags get propagated from the
-    // flash-attn crate to the binary crate that uses it.
-    #[cfg(feature = "flash-attn")]
-    println!("cargo:rustc-link-lib=dylib=cudart");
-    #[cfg(feature = "flash-attn")]
-    println!("cargo:rustc-link-lib=dylib=stdc++");
     if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
         compute_cap = compute_cap_str
             .parse::<usize>()
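
Why the move works: cargo forwards rustc-link-lib directives emitted by a dependency's build script into the final link of any binary that depends on that crate, so the directives only need to live in the crate that actually owns the native code, and the TODO above becomes moot. A minimal sketch of such a build script, assuming a hypothetical crate that wraps a prebuilt native CUDA library:

// build.rs of the library crate that owns the native code (hypothetical crate).
// Cargo propagates these link directives to every downstream binary,
// so consumers no longer have to repeat them in their own build scripts.
fn main() {
    // Re-run this script only when it changes.
    println!("cargo:rerun-if-changed=build.rs");
    // Dynamically link the CUDA runtime and the C++ standard library.
    println!("cargo:rustc-link-lib=dylib=cudart");
    println!("cargo:rustc-link-lib=dylib=stdc++");
}
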
5 changes: 0 additions & 5 deletions candle-examples/examples/whisper/main.rs
@@ -258,11 +258,6 @@ struct Args {
 fn main() -> Result<()> {
     let args = Args::parse();
     let device = candle_examples::device(args.cpu)?;
-    let op = candle_flash_attn::FlashHdim32Sm80;
-    let t = Tensor::zeros((3, 2), DType::F16, &device)?;
-    let t2 = Tensor::zeros((3, 2), DType::F16, &device)?;
-    let t3 = Tensor::zeros((3, 2), DType::F16, &device)?;
-    t.custom_op3(&t2, &t3, op);
     let default_model = "openai/whisper-tiny.en".to_string();
     let path = std::path::PathBuf::from(default_model.clone());
     let default_revision = "refs/pr/15".to_string();
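
The lines removed here look like a leftover smoke test: they build zero tensors and invoke the flash-attn custom op directly, which has nothing to do with the whisper example itself and was presumably what forced candle-examples to carry the linker flags deleted above.
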
2 changes: 2 additions & 0 deletions candle-flash-attn/build.rs
@@ -59,6 +59,8 @@ fn main() -> Result<()> {
     }
     println!("cargo:rustc-link-search={}", out_dir.display());
     println!("cargo:rustc-link-lib=flashattention");
+    println!("cargo:rustc-link-lib=dylib=cudart");
+    println!("cargo:rustc-link-lib=dylib=stdc++");
 
     /* laurent: I tried using the cc cuda integration as below but this led to ptxas never
     finishing to run for some reason. Calling nvcc manually worked fine.
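
For context on the abandoned cc route the comment mentions, a rough sketch of what it could look like (hypothetical kernel path and flags; cc would need to be a build-dependency, and this is not what the crate actually ships):

// Hypothetical build.rs using the cc crate's CUDA integration, the approach
// the comment says stalled in ptxas; the crate invokes nvcc manually instead.
fn main() {
    cc::Build::new()
        .cuda(true) // compile with nvcc rather than the host C++ compiler
        .include("cutlass/include") // illustrative include path
        .file("kernels/flash_fwd_hdim32_fp16_sm80.cu") // illustrative kernel
        .flag("-gencode")
        .flag("arch=compute_80,code=sm_80") // target sm_80 (Ampere)
        .compile("flashattention"); // emits libflashattention.a and link flags
}
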
