Support sd3.5 medium and MMDiT-X #2587

Merged
merged 5 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions candle-examples/examples/stable-diffusion-3/README.md
@@ -1,18 +1,26 @@
-# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium
+# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3/3.5

![](assets/stable-diffusion-3.jpg)

-*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*
+*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*, generated by Stable Diffusion 3 Medium

Stable Diffusion 3 Medium is a text-to-image model based on the Multimodal Diffusion Transformer (MMDiT) architecture.

- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- [research paper](https://arxiv.org/pdf/2403.03206)
- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)
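
As an aside (not part of the upstream README), the joint-attention idea at the heart of MMDiT can be sketched in a few lines of candle, mirroring `joint_attn` in this PR's blocks.rs: context and image-latent tokens are concatenated along the sequence axis, attended over in one pass, then split back into the two streams. The toy shapes and the single-head `naive_attention` helper below are illustrative assumptions, not the example's API:

```rust
use candle_core::{Device, Result, Tensor};

// Single-head scaled dot-product attention with q = k = v, purely for illustration.
fn naive_attention(t: &Tensor) -> Result<Tensor> {
    let dim = t.dim(2)? as f64;
    let scores = (t.matmul(&t.transpose(1, 2)?)? / dim.sqrt())?;
    let weights = candle_nn::ops::softmax(&scores, 2)?;
    weights.matmul(t)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy sizes; the real model uses much longer sequences and multi-head attention.
    let (batch, ctx_len, img_len, dim) = (1, 7, 16, 8);
    let context = Tensor::randn(0f32, 1.0, (batch, ctx_len, dim), &device)?;
    let x = Tensor::randn(0f32, 1.0, (batch, img_len, dim), &device)?;

    // Concatenate both streams along the sequence axis and attend jointly...
    let joint = Tensor::cat(&[&context, &x], 1)?;
    let attn = naive_attention(&joint)?;

    // ...then split the result back into the context and image streams.
    let context_out = attn.narrow(1, 0, ctx_len)?;
    let x_out = attn.narrow(1, ctx_len, img_len)?;
    println!("context: {:?}, x: {:?}", context_out.dims(), x_out.dims());
    Ok(())
}
```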

Stable Diffusion 3.5 is a family of text-to-image models with the latest improvements:
- [announcement blog post](https://stability.ai/news/introducing-stable-diffusion-3-5)

It has three variants:
- [Stable Diffusion 3.5 Large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) @ 8.1b params, with a scaled and slightly modified MMDiT architecture.
- [Stable Diffusion 3.5 Large Turbo](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo), a distilled version that enables 4-step inference.
- [Stable Diffusion 3.5 Medium](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) @ 2.5b params, with the improved MMDiT-X architecture.

## Getting access to the weights

-The weights of Stable Diffusion 3 Medium is released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting [the repo on HuggingFace Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium) to gain access to the weights for your HuggingFace account.
+The weights of Stable Diffusion 3/3.5 are released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the repos on HuggingFace Hub to gain access to the weights for your HuggingFace account.

To allow your computer to gain access to the public-gated repos on HuggingFace, you might need to create a [HuggingFace User Access Token](https://huggingface.co/docs/hub/en/security-tokens) (recommended) and log in on your computer if you haven't done so before. A convenient way to log in is to use [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/cli):
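
Not in the original README: as a hedged illustration, the same gated access can be done programmatically with the `hf_hub` crate the example uses. `Api::new()` picks up the token cached by `huggingface-cli login`, while `ApiBuilder` lets you pass one explicitly. The env-var name and the exact filename here are assumptions for the sketch:

```rust
use anyhow::Result;

// Sketch: fetch a file from a public-gated repo once a User Access Token exists.
fn fetch_gated_weights() -> Result<std::path::PathBuf> {
    // `Api::new()` would also work, using the token cached by `huggingface-cli login`.
    let api = hf_hub::api::sync::ApiBuilder::new()
        .with_token(std::env::var("HF_TOKEN").ok()) // assumption: token passed via env var
        .build()?;
    let repo = api.model("stabilityai/stable-diffusion-3-medium".to_string());
    // Assumed filename; the example itself picks the right file per `--which`.
    Ok(repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?)
}
```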

@@ -27,10 +35,12 @@ On the first run, the weights will be automatically downloaded from the Huggingface Hub

```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- \
---height 1024 --width 1024 \
+--which 3-medium --height 1024 --width 1024 \
--prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
```

To use a different model, change the value of the `--which` option. (Possible values: `3-medium`, `3.5-large`, `3.5-large-turbo` and `3.5-medium`.)

To display the other available options:

```shell
@@ -45,7 +55,7 @@ cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- -

## Performance Benchmark

-Below benchmark is done by generating 1024-by-1024 image from 28 steps of Euler sampling and measure the average speed (iteration per seconds).
+The benchmark below was done with Stable Diffusion 3 Medium by generating a 1024-by-1024 image with 28 steps of Euler sampling and measuring the average speed (iterations per second).

[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) are based on commit [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).
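
For reference, the measurement itself is simple. A hedged sketch follows; the `sampling_step` closure stands in for one denoising iteration of the real example and is not part of this PR:

```rust
use std::time::Instant;

// Average iterations per second over 28 Euler sampling steps.
fn iterations_per_second<F>(mut sampling_step: F) -> anyhow::Result<f64>
where
    F: FnMut() -> anyhow::Result<()>,
{
    let steps = 28;
    let start = Instant::now();
    for _ in 0..steps {
        sampling_step()?; // one denoising step of the MMDiT + sampler
    }
    Ok(steps as f64 / start.elapsed().as_secs_f64())
}
```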

44 changes: 37 additions & 7 deletions candle-examples/examples/stable-diffusion-3/main.rs
@@ -19,13 +19,15 @@ enum Which {
V3_5Large,
#[value(name = "3.5-large-turbo")]
V3_5LargeTurbo,
#[value(name = "3.5-medium")]
V3_5Medium,
}

impl Which {
fn is_3_5(&self) -> bool {
match self {
Self::V3Medium => false,
-Self::V3_5Large | Self::V3_5LargeTurbo => true,
+Self::V3_5Large | Self::V3_5LargeTurbo | Self::V3_5Medium => true,
}
}
}
@@ -117,36 +119,59 @@ fn main() -> Result<()> {
let default_inference_steps = match which {
Which::V3_5Large => 28,
Which::V3_5LargeTurbo => 4,
Which::V3_5Medium => 28,
Which::V3Medium => 28,
};
let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
let default_cfg_scale = match which {
Which::V3_5Large => 4.0,
Which::V3_5LargeTurbo => 1.0,
Which::V3_5Medium => 4.0,
Which::V3Medium => 4.0,
};
let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);

let api = hf_hub::api::sync::Api::new()?;
let (mmdit_config, mut triple, vb) = if which.is_3_5() {
-let sai_repo = {
+let sai_repo_for_text_encoders = {
let name = match which {
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",

// Unfortunately, stabilityai/stable-diffusion-3.5-medium doesn't have the monolithic text encoders that are usually
// placed under the text_encoders directory, as is the case in stabilityai/stable-diffusion-3.5-large and -large-turbo.
// To make things worse, it currently only has the partitioned model.fp16-00001-of-00002.safetensors and
// model.fp16-00002-of-00002.safetensors under the text_encoder_3 directory for the t5xxl_fp16.safetensors model,
// which means the two partitions would have to be merged to recover the monolithic text encoder, and that is not trivial.
// Since the situation may change, we don't want to spend effort handling the uniqueness of
// stabilityai/stable-diffusion-3.5-medium (different paths plus merging the two partition files for t5xxl_fp16.safetensors),
// so for now we use the text encoder models from the stabilityai/stable-diffusion-3.5-large repository.
// TODO: Change to "stabilityai/stable-diffusion-3.5-medium" once the maintainers of the repository add back the monolithic text encoders.
Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-large",
Which::V3Medium => unreachable!(),
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let sai_repo_for_mmdit = {
let name = match which {
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-medium",
Which::V3Medium => unreachable!(),
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
-let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?;
-let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?;
-let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?;
+let clip_g_file = sai_repo_for_text_encoders.get("text_encoders/clip_g.safetensors")?;
+let clip_l_file = sai_repo_for_text_encoders.get("text_encoders/clip_l.safetensors")?;
+let t5xxl_file = sai_repo_for_text_encoders.get("text_encoders/t5xxl_fp16.safetensors")?;
let model_file = {
let model_file = match which {
Which::V3_5Large => "sd3.5_large.safetensors",
Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
Which::V3_5Medium => "sd3.5_medium.safetensors",
Which::V3Medium => unreachable!(),
};
-sai_repo.get(model_file)?
+sai_repo_for_mmdit.get(model_file)?
};
let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
&clip_g_file,
@@ -157,7 +182,12 @@ fn main() -> Result<()> {
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
};
-(MMDiTConfig::sd3_5_large(), triple, vb)
+match which {
+Which::V3_5Large => (MMDiTConfig::sd3_5_large(), triple, vb),
+Which::V3_5LargeTurbo => (MMDiTConfig::sd3_5_large(), triple, vb),
+Which::V3_5Medium => (MMDiTConfig::sd3_5_medium(), triple, vb),
+Which::V3Medium => unreachable!(),
+}
} else {
let sai_repo = {
let name = "stabilityai/stable-diffusion-3-medium";
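
The workaround comment in main.rs above sidesteps the sharded text_encoder_3 weights in the 3.5-medium repo. As a hedged sketch of the alternative it rules out: candle's `VarBuilder::from_mmaped_safetensors` (used elsewhere in this very diff) accepts multiple files and resolves each tensor name across them, so the two partitions could in principle be memory-mapped together without a manual merge. The shard paths below are assumptions based on the repo layout described in the comment:

```rust
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

// Hypothetical alternative: memory-map both t5xxl shards together instead of
// merging them into one file. `shard_paths` would come from two repo.get() calls
// on the 3.5-medium repo, e.g. text_encoder_3/model.fp16-0000{1,2}-of-00002.safetensors.
fn load_sharded_t5(shard_paths: &[std::path::PathBuf], device: &Device) -> Result<VarBuilder<'static>> {
    unsafe { VarBuilder::from_mmaped_safetensors(shard_paths, DType::F16, device) }
}
```

Whether the tensor names inside those shards line up with what the example's T5 loader expects is the non-trivial part the comment alludes to, which is presumably why the PR defers to the -large repo for now.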
191 changes: 177 additions & 14 deletions candle-transformers/src/models/mmdit/blocks.rs
@@ -36,7 +36,6 @@ impl Module for LayerNormNoAffine {

impl DiTBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
-// {'hidden_size': 1536, 'num_heads': 24}
let norm1 = LayerNormNoAffine::new(1e-6);
let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
let norm2 = LayerNormNoAffine::new(1e-6);
@@ -103,6 +102,117 @@ impl DiTBlock {
}
}

pub struct SelfAttnModulateIntermediates {
gate_msa: Tensor,
shift_mlp: Tensor,
scale_mlp: Tensor,
gate_mlp: Tensor,
gate_msa2: Tensor,
}

pub struct SelfAttnDiTBlock {
norm1: LayerNormNoAffine,
attn: AttnProjections,
attn2: AttnProjections,
norm2: LayerNormNoAffine,
mlp: Mlp,
ada_ln_modulation: nn::Sequential,
}

impl SelfAttnDiTBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
let norm1 = LayerNormNoAffine::new(1e-6);
let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
let attn2 = AttnProjections::new(hidden_size, num_heads, vb.pp("attn2"))?;
let norm2 = LayerNormNoAffine::new(1e-6);
let mlp_ratio = 4;
let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
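// Three (shift, scale, gate) triples: one for the joint attention (attn), one for
// the extra self-attention (attn2), and one for the MLP, hence n_mods = 9.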
let n_mods = 9;
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
hidden_size,
n_mods * hidden_size,
vb.pp("adaLN_modulation.1"),
)?);

Ok(Self {
norm1,
attn,
attn2,
norm2,
mlp,
ada_ln_modulation,
})
}

pub fn pre_attention(
&self,
x: &Tensor,
c: &Tensor,
) -> Result<(Qkv, Qkv, SelfAttnModulateIntermediates)> {
let modulation = self.ada_ln_modulation.forward(c)?;
let chunks = modulation.chunk(9, D::Minus1)?;
let (
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
shift_msa2,
scale_msa2,
gate_msa2,
) = (
chunks[0].clone(),
chunks[1].clone(),
chunks[2].clone(),
chunks[3].clone(),
chunks[4].clone(),
chunks[5].clone(),
chunks[6].clone(),
chunks[7].clone(),
chunks[8].clone(),
);

let norm_x = self.norm1.forward(x)?;
let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
let qkv = self.attn.pre_attention(&modulated_x)?;

let modulated_x2 = modulate(&norm_x, &shift_msa2, &scale_msa2)?;
let qkv2 = self.attn2.pre_attention(&modulated_x2)?;

Ok((
qkv,
qkv2,
SelfAttnModulateIntermediates {
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
gate_msa2,
},
))
}

pub fn post_attention(
&self,
attn: &Tensor,
attn2: &Tensor,
x: &Tensor,
mod_interm: &SelfAttnModulateIntermediates,
) -> Result<Tensor> {
let attn_out = self.attn.post_attention(attn)?;
let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
let attn_out2 = self.attn2.post_attention(attn2)?;
let x = x.add(&attn_out2.broadcast_mul(&mod_interm.gate_msa2.unsqueeze(1)?)?)?;

let norm_x = self.norm2.forward(&x)?;
let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
let mlp_out = self.mlp.forward(&modulated_x)?;
let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
Ok(x)
}
}
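
// Summarizing SelfAttnDiTBlock in equations (a paraphrase of the code above, not a
// formula quoted from the paper), with s/c/g the shift/scale/gate chunks and
// mod(h, s, c) = s + h * (1 + c) as in the `modulate` helper later in this file:
//   h   = LN(x)
//   x1  = x  + g_msa  * Attn_joint(mod(h, s_msa,  c_msa))
//   x2  = x1 + g_msa2 * Attn_self(mod(h, s_msa2, c_msa2))
//   out = x2 + g_mlp  * MLP(mod(LN(x2), s_mlp, c_mlp))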

pub struct QkvOnlyDiTBlock {
norm1: LayerNormNoAffine,
attn: QkvOnlyAttnProjections,
@@ -190,14 +300,18 @@ fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
}

-pub struct JointBlock {
+pub trait JointBlock {
+fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)>;
+}
+
+pub struct MMDiTJointBlock {
x_block: DiTBlock,
context_block: DiTBlock,
num_heads: usize,
use_flash_attn: bool,
}

-impl JointBlock {
+impl MMDiTJointBlock {
pub fn new(
hidden_size: usize,
num_heads: usize,
@@ -214,8 +328,10 @@ impl JointBlock {
use_flash_attn,
})
}
+}

-pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
+impl JointBlock for MMDiTJointBlock {
+fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
let (context_attn, x_attn) =
@@ -228,6 +344,49 @@ impl JointBlock {
}
}

pub struct MMDiTXJointBlock {
x_block: SelfAttnDiTBlock,
context_block: DiTBlock,
num_heads: usize,
use_flash_attn: bool,
}

impl MMDiTXJointBlock {
pub fn new(
hidden_size: usize,
num_heads: usize,
use_flash_attn: bool,
vb: nn::VarBuilder,
) -> Result<Self> {
let x_block = SelfAttnDiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;

Ok(Self {
x_block,
context_block,
num_heads,
use_flash_attn,
})
}
}

impl JointBlock for MMDiTXJointBlock {
fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
let (x_qkv, x_qkv2, x_interm) = self.x_block.pre_attention(x, c)?;
let (context_attn, x_attn) =
joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
let x_attn2 = attn(&x_qkv2, self.num_heads, self.use_flash_attn)?;
let context_out =
self.context_block
.post_attention(&context_attn, context, &context_interm)?;
let x_out = self
.x_block
.post_attention(&x_attn, &x_attn2, x, &x_interm)?;
Ok((context_out, x_out))
}
}

pub struct ContextQkvOnlyJointBlock {
x_block: DiTBlock,
context_block: QkvOnlyDiTBlock,
@@ -309,26 +468,30 @@ fn joint_attn(
v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
};

-let (batch_size, seqlen, _) = qkv.q.dims3()?;
+let seqlen = qkv.q.dim(1)?;
+let attn = attn(&qkv, num_heads, use_flash_attn)?;
+let context_qkv_seqlen = context_qkv.q.dim(1)?;
+let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
+let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
+
+Ok((context_attn, x_attn))
+}
+
+fn attn(qkv: &Qkv, num_heads: usize, use_flash_attn: bool) -> Result<Tensor> {
+let batch_size = qkv.q.dim(0)?;
+let seqlen = qkv.q.dim(1)?;
let qkv = Qkv {
q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
-v: qkv.v,
+v: qkv.v.clone(),
};

let headdim = qkv.q.dim(D::Minus1)?;
let softmax_scale = 1.0 / (headdim as f64).sqrt();

let attn = if use_flash_attn {
flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
} else {
flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
};

-let attn = attn.reshape((batch_size, seqlen, ()))?;
-let context_qkv_seqlen = context_qkv.q.dim(1)?;
-let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
-let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
-
-Ok((context_attn, x_attn))
+attn.reshape((batch_size, seqlen, ()))
}
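
With `JointBlock` now a trait, the model can hold a heterogeneous stack of blocks behind trait objects. A hedged sketch of how the caller side might select between the two block types; the surrounding mmdit.rs is not shown in this diff, so the function and prefix names here are assumptions:

```rust
use candle_core::Result;
use candle_nn as nn;

// Hypothetical construction of a block stack; the real wiring lives in mmdit.rs.
fn build_blocks(
    use_mmdit_x: bool,
    depth: usize,
    hidden_size: usize,
    num_heads: usize,
    use_flash_attn: bool,
    vb: nn::VarBuilder,
) -> Result<Vec<Box<dyn JointBlock>>> {
    let mut blocks: Vec<Box<dyn JointBlock>> = Vec::with_capacity(depth);
    for i in 0..depth {
        let vb_block = vb.pp(format!("joint_blocks.{i}")); // assumed prefix
        if use_mmdit_x {
            // MMDiT-X (sd3.5 medium): x-stream blocks carry the extra self-attention.
            blocks.push(Box::new(MMDiTXJointBlock::new(
                hidden_size, num_heads, use_flash_attn, vb_block,
            )?));
        } else {
            // Plain MMDiT (sd3 medium, sd3.5 large / large-turbo).
            blocks.push(Box::new(MMDiTJointBlock::new(
                hidden_size, num_heads, use_flash_attn, vb_block,
            )?));
        }
    }
    Ok(blocks)
}
```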