Skip to content

Commit

Permalink
Remove some unnecessary clones.
Browse files: browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Oct 26, 2024
1 parent cb74318 commit 6a60fe0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 17 deletions.
21 changes: 5 additions & 16 deletions candle-examples/examples/stable-diffusion-3/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ fn main() -> Result<()> {

// Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image.
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
- autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
+ autoencoder.decode(&((x / 1.5305)? + 0.0609)?)?
}
} else {
let sai_repo = {
Expand All @@ -212,20 +212,12 @@ fn main() -> Result<()> {
};
let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
let vb_fp16 = unsafe {
- candle_nn::VarBuilder::from_mmaped_safetensors(
- &[model_file.clone()],
- DType::F16,
- &device,
- )?
+ candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
};

let (context, y) = {
let vb_fp32 = unsafe {
- candle_nn::VarBuilder::from_mmaped_safetensors(
- &[model_file.clone()],
- DType::F32,
- &device,
- )?
+ candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)?
};
let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
vb_fp16.pp("text_encoders"),
Expand Down Expand Up @@ -271,15 +263,12 @@ fn main() -> Result<()> {
};

{
- let vb_vae = vb_fp16
- .clone()
- .rename_f(sd3_vae_vb_rename)
- .pp("first_stage_model");
+ let vb_vae = vb_fp16.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;

// Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image.
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
- autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
+ autoencoder.decode(&((x / 1.5305)? + 0.0609)?)?
}
};
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
Expand Down
2 changes: 1 addition & 1 deletion candle-examples/examples/stable-diffusion-3/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub fn euler_sample(

let timestep = (*s_curr) * 1000.0;
let noise_pred = mmdit.forward(
- &Tensor::cat(&[x.clone(), x.clone()], 0)?,
+ &Tensor::cat(&[&x, &x], 0)?,
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
Expand Down

0 comments on commit 6a60fe0

Please sign in to comment.