diff --git a/candle-examples/examples/flux/README.md b/candle-examples/examples/flux/README.md new file mode 100644 index 0000000000..528f058e38 --- /dev/null +++ b/candle-examples/examples/flux/README.md @@ -0,0 +1,19 @@ +# candle-flux: image generation with latent rectified flow transformers + +![rusty robot holding a candle](./assets/flux-robot.jpg) + +Flux is a 12B rectified flow transformer capable of generating images from text +descriptions, +[huggingface](https://huggingface.co/black-forest-labs/FLUX.1-schnell), +[github](https://github.com/black-forest-labs/flux), +[blog post](https://blackforestlabs.ai/announcing-black-forest-labs/). + + +## Running the model + +```bash +cargo run --features cuda --example flux -r -- \ + --height 1024 --width 1024 + --prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k" +``` + diff --git a/candle-examples/examples/flux/assets/flux-robot.jpg b/candle-examples/examples/flux/assets/flux-robot.jpg new file mode 100644 index 0000000000..f715743346 Binary files /dev/null and b/candle-examples/examples/flux/assets/flux-robot.jpg differ diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs new file mode 100644 index 0000000000..826174bc69 --- /dev/null +++ b/candle-examples/examples/flux/main.rs @@ -0,0 +1,182 @@ +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use candle_transformers::models::{clip, flux, t5}; + +use anyhow::{Error as E, Result}; +use candle::{IndexOp, Module, Tensor}; +use candle_nn::VarBuilder; +use clap::Parser; +use tokenizers::Tokenizer; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Args { + /// The prompt to be used for image generation. + #[arg(long, default_value = "A rusty robot walking on a beach")] + prompt: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// The height in pixels of the generated image. + #[arg(long)] + height: Option, + + /// The width in pixels of the generated image. + #[arg(long)] + width: Option, + + #[arg(long)] + decode_only: Option, +} + +fn run(args: Args) -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let Args { + prompt, + cpu, + height, + width, + tracing, + decode_only, + } = args; + let width = width.unwrap_or(1360); + let height = height.unwrap_or(768); + + let _guard = if tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let api = hf_hub::api::sync::Api::new()?; + let bf_repo = api.repo(hf_hub::Repo::model( + "black-forest-labs/FLUX.1-schnell".to_string(), + )); + let device = candle_examples::device(cpu)?; + let dtype = device.bf16_default_to_f32(); + let img = match decode_only { + None => { + let t5_emb = { + let repo = api.repo(hf_hub::Repo::with_revision( + "google/t5-v1_1-xxl".to_string(), + hf_hub::RepoType::Model, + "refs/pr/2".to_string(), + )); + let model_file = repo.get("model.safetensors")?; + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; + let config_filename = repo.get("config.json")?; + let config = std::fs::read_to_string(config_filename)?; + let config: t5::Config = serde_json::from_str(&config)?; + let mut model = t5::T5EncoderModel::load(vb, &config)?; + let tokenizer_filename = api + .model("lmz/mt5-tokenizers".to_string()) + .get("t5-v1_1-xxl.tokenizer.json")?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let mut tokens = tokenizer + .encode(prompt.as_str(), true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + tokens.resize(256, 0); + let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; + println!("{input_token_ids}"); + model.forward(&input_token_ids)? + }; + println!("T5\n{t5_emb}"); + let clip_emb = { + let repo = api.repo(hf_hub::Repo::model( + "openai/clip-vit-large-patch14".to_string(), + )); + let model_file = repo.get("model.safetensors")?; + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; + // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json + let config = clip::text_model::ClipTextConfig { + vocab_size: 49408, + projection_dim: 768, + activation: clip::text_model::Activation::QuickGelu, + intermediate_size: 3072, + embed_dim: 768, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 12, + num_attention_heads: 12, + }; + let model = + clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config)?; + let tokenizer_filename = repo.get("tokenizer.json")?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let tokens = tokenizer + .encode(prompt.as_str(), true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; + println!("{input_token_ids}"); + model.forward(&input_token_ids)? + }; + println!("CLIP\n{clip_emb}"); + let img = { + let model_file = bf_repo.get("flux1-schnell.sft")?; + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; + let cfg = flux::model::Config::schnell(); + let model = flux::model::Flux::new(&cfg, vb)?; + + let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?; + let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?; + println!("{state:?}"); + let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell + println!("{timesteps:?}"); + flux::sampling::denoise( + &model, + &state.img, + &state.img_ids, + &state.txt, + &state.txt_ids, + &state.vec, + ×teps, + 4., + )? + }; + flux::sampling::unpack(&img, height, width)? + } + Some(file) => { + let mut st = candle::safetensors::load(file, &device)?; + st.remove("img").unwrap().to_dtype(dtype)? + } + }; + println!("latent img\n{img}"); + + let img = { + let model_file = bf_repo.get("ae.sft")?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; + let cfg = flux::autoencoder::Config::schnell(); + let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?; + model.decode(&img)? + }; + println!("img\n{img}"); + let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?; + candle_examples::save_image(&img.i(0)?, "out.jpg")?; + Ok(()) +} + +fn main() -> Result<()> { + let args = Args::parse(); + run(args) +} diff --git a/candle-transformers/src/models/flux/autoencoder.rs b/candle-transformers/src/models/flux/autoencoder.rs new file mode 100644 index 0000000000..8c2aebbdc4 --- /dev/null +++ b/candle-transformers/src/models/flux/autoencoder.rs @@ -0,0 +1,440 @@ +use candle::{Result, Tensor, D}; +use candle_nn::{conv2d, group_norm, Conv2d, GroupNorm, VarBuilder}; + +// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/modules/autoencoder.py#L9 +#[derive(Debug, Clone)] +pub struct Config { + pub resolution: usize, + pub in_channels: usize, + pub ch: usize, + pub out_ch: usize, + pub ch_mult: Vec, + pub num_res_blocks: usize, + pub z_channels: usize, + pub scale_factor: f64, + pub shift_factor: f64, +} + +impl Config { + // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L47 + pub fn dev() -> Self { + Self { + resolution: 256, + in_channels: 3, + ch: 128, + out_ch: 3, + ch_mult: vec![1, 2, 4, 4], + num_res_blocks: 2, + z_channels: 16, + scale_factor: 0.3611, + shift_factor: 0.1159, + } + } + + // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L79 + pub fn schnell() -> Self { + Self { + resolution: 256, + in_channels: 3, + ch: 128, + out_ch: 3, + ch_mult: vec![1, 2, 4, 4], + num_res_blocks: 2, + z_channels: 16, + scale_factor: 0.3611, + shift_factor: 0.1159, + } + } +} + +fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result { + let dim = q.dim(D::Minus1)?; + let scale_factor = 1.0 / (dim as f64).sqrt(); + let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?; + candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v) +} + +#[derive(Debug, Clone)] +struct AttnBlock { + q: Conv2d, + k: Conv2d, + v: Conv2d, + proj_out: Conv2d, + norm: GroupNorm, +} + +impl AttnBlock { + fn new(in_c: usize, vb: VarBuilder) -> Result { + let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?; + let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?; + let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?; + let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?; + let norm = group_norm(32, in_c, 1e-6, vb.pp("norm"))?; + Ok(Self { + q, + k, + v, + proj_out, + norm, + }) + } +} + +impl candle::Module for AttnBlock { + fn forward(&self, xs: &Tensor) -> Result { + let init_xs = xs; + let xs = xs.apply(&self.norm)?; + let q = xs.apply(&self.q)?; + let k = xs.apply(&self.k)?; + let v = xs.apply(&self.v)?; + let (b, c, h, w) = q.dims4()?; + let q = q.flatten_from(2)?.t()?.unsqueeze(1)?; + let k = k.flatten_from(2)?.t()?.unsqueeze(1)?; + let v = v.flatten_from(2)?.t()?.unsqueeze(1)?; + let xs = scaled_dot_product_attention(&q, &k, &v)?; + let xs = xs.squeeze(1)?.t()?.reshape((b, c, h, w))?; + xs.apply(&self.proj_out)? + init_xs + } +} + +#[derive(Debug, Clone)] +struct ResnetBlock { + norm1: GroupNorm, + conv1: Conv2d, + norm2: GroupNorm, + conv2: Conv2d, + nin_shortcut: Option, +} + +impl ResnetBlock { + fn new(in_c: usize, out_c: usize, vb: VarBuilder) -> Result { + let conv_cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let norm1 = group_norm(32, in_c, 1e-6, vb.pp("norm1"))?; + let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?; + let norm2 = group_norm(32, out_c, 1e-6, vb.pp("norm2"))?; + let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?; + let nin_shortcut = if in_c == out_c { + None + } else { + Some(conv2d( + in_c, + out_c, + 1, + Default::default(), + vb.pp("nin_shortcut"), + )?) + }; + Ok(Self { + norm1, + conv1, + norm2, + conv2, + nin_shortcut, + }) + } +} + +impl candle::Module for ResnetBlock { + fn forward(&self, xs: &Tensor) -> Result { + let h = xs + .apply(&self.norm1)? + .apply(&candle_nn::Activation::Swish)? + .apply(&self.conv1)? + .apply(&self.norm2)? + .apply(&candle_nn::Activation::Swish)? + .apply(&self.conv2)?; + match self.nin_shortcut.as_ref() { + None => xs + h, + Some(c) => xs.apply(c)? + h, + } + } +} + +#[derive(Debug, Clone)] +struct Downsample { + conv: Conv2d, +} + +impl Downsample { + fn new(in_c: usize, vb: VarBuilder) -> Result { + let conv_cfg = candle_nn::Conv2dConfig { + stride: 2, + ..Default::default() + }; + let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?; + Ok(Self { conv }) + } +} + +impl candle::Module for Downsample { + fn forward(&self, xs: &Tensor) -> Result { + let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?; + let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?; + xs.apply(&self.conv) + } +} + +#[derive(Debug, Clone)] +struct Upsample { + conv: Conv2d, +} + +impl Upsample { + fn new(in_c: usize, vb: VarBuilder) -> Result { + let conv_cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?; + Ok(Self { conv }) + } +} + +impl candle::Module for Upsample { + fn forward(&self, xs: &Tensor) -> Result { + let (_, _, h, w) = xs.dims4()?; + xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv) + } +} + +#[derive(Debug, Clone)] +struct DownBlock { + block: Vec, + downsample: Option, +} + +#[derive(Debug, Clone)] +pub struct Encoder { + conv_in: Conv2d, + mid_block_1: ResnetBlock, + mid_attn_1: AttnBlock, + mid_block_2: ResnetBlock, + norm_out: GroupNorm, + conv_out: Conv2d, + down: Vec, +} + +impl Encoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let conv_cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let mut block_in = cfg.ch; + let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?; + + let mut down = Vec::with_capacity(cfg.ch_mult.len()); + let vb_d = vb.pp("down"); + for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate() { + let mut block = Vec::with_capacity(cfg.num_res_blocks); + let vb_d = vb_d.pp(i_level); + let vb_b = vb_d.pp("block"); + let in_ch_mult = if i_level == 0 { + 1 + } else { + cfg.ch_mult[i_level - 1] + }; + block_in = cfg.ch * in_ch_mult; + let block_out = cfg.ch * ch_mult; + for i_block in 0..cfg.num_res_blocks { + let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?; + block.push(b); + block_in = block_out; + } + let downsample = if i_level != cfg.ch_mult.len() - 1 { + Some(Downsample::new(block_in, vb_d.pp("downsample"))?) + } else { + None + }; + let block = DownBlock { block, downsample }; + down.push(block) + } + + let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?; + let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?; + let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?; + let conv_out = conv2d(block_in, 2 * cfg.z_channels, 3, conv_cfg, vb.pp("conv_out"))?; + let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?; + Ok(Self { + conv_in, + mid_block_1, + mid_attn_1, + mid_block_2, + norm_out, + conv_out, + down, + }) + } +} + +impl candle_nn::Module for Encoder { + fn forward(&self, xs: &Tensor) -> Result { + let mut h = xs.apply(&self.conv_in)?; + for block in self.down.iter() { + for b in block.block.iter() { + h = h.apply(b)? + } + if let Some(ds) = block.downsample.as_ref() { + h = h.apply(ds)? + } + } + h.apply(&self.mid_block_1)? + .apply(&self.mid_attn_1)? + .apply(&self.mid_block_2)? + .apply(&self.norm_out)? + .apply(&candle_nn::Activation::Swish)? + .apply(&self.conv_out) + } +} + +#[derive(Debug, Clone)] +struct UpBlock { + block: Vec, + upsample: Option, +} + +#[derive(Debug, Clone)] +pub struct Decoder { + conv_in: Conv2d, + mid_block_1: ResnetBlock, + mid_attn_1: AttnBlock, + mid_block_2: ResnetBlock, + norm_out: GroupNorm, + conv_out: Conv2d, + up: Vec, +} + +impl Decoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let conv_cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let mut block_in = cfg.ch * cfg.ch_mult.last().unwrap_or(&1); + let conv_in = conv2d(cfg.z_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?; + let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?; + let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?; + let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?; + + let mut up = Vec::with_capacity(cfg.ch_mult.len()); + let vb_u = vb.pp("up"); + for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate().rev() { + let block_out = cfg.ch * ch_mult; + let vb_u = vb_u.pp(i_level); + let vb_b = vb_u.pp("block"); + let mut block = Vec::with_capacity(cfg.num_res_blocks + 1); + for i_block in 0..=cfg.num_res_blocks { + let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?; + block.push(b); + block_in = block_out; + } + let upsample = if i_level != 0 { + Some(Upsample::new(block_in, vb_u.pp("upsample"))?) + } else { + None + }; + let block = UpBlock { block, upsample }; + up.push(block) + } + up.reverse(); + + let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?; + let conv_out = conv2d(block_in, cfg.out_ch, 3, conv_cfg, vb.pp("conv_out"))?; + Ok(Self { + conv_in, + mid_block_1, + mid_attn_1, + mid_block_2, + norm_out, + conv_out, + up, + }) + } +} + +impl candle_nn::Module for Decoder { + fn forward(&self, xs: &Tensor) -> Result { + let h = xs.apply(&self.conv_in)?; + let mut h = h + .apply(&self.mid_block_1)? + .apply(&self.mid_attn_1)? + .apply(&self.mid_block_2)?; + for block in self.up.iter().rev() { + for b in block.block.iter() { + h = h.apply(b)? + } + if let Some(us) = block.upsample.as_ref() { + h = h.apply(us)? + } + } + h.apply(&self.norm_out)? + .apply(&candle_nn::Activation::Swish)? + .apply(&self.conv_out) + } +} + +#[derive(Debug, Clone)] +pub struct DiagonalGaussian { + sample: bool, + chunk_dim: usize, +} + +impl DiagonalGaussian { + pub fn new(sample: bool, chunk_dim: usize) -> Result { + Ok(Self { sample, chunk_dim }) + } +} + +impl candle_nn::Module for DiagonalGaussian { + fn forward(&self, xs: &Tensor) -> Result { + let chunks = xs.chunk(2, self.chunk_dim)?; + if self.sample { + let std = (&chunks[1] * 0.5)?.exp()?; + &chunks[0] + (std * chunks[0].randn_like(0., 1.))? + } else { + Ok(chunks[0].clone()) + } + } +} + +#[derive(Debug, Clone)] +pub struct AutoEncoder { + encoder: Encoder, + decoder: Decoder, + reg: DiagonalGaussian, + shift_factor: f64, + scale_factor: f64, +} + +impl AutoEncoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let encoder = Encoder::new(cfg, vb.pp("encoder"))?; + let decoder = Decoder::new(cfg, vb.pp("decoder"))?; + let reg = DiagonalGaussian::new(true, 1)?; + Ok(Self { + encoder, + decoder, + reg, + scale_factor: cfg.scale_factor, + shift_factor: cfg.shift_factor, + }) + } + + pub fn encode(&self, xs: &Tensor) -> Result { + let z = xs.apply(&self.encoder)?.apply(&self.reg)?; + (z - self.shift_factor)? * self.scale_factor + } + pub fn decode(&self, xs: &Tensor) -> Result { + let xs = ((xs / self.scale_factor)? + self.shift_factor)?; + xs.apply(&self.decoder) + } +} + +impl candle::Module for AutoEncoder { + fn forward(&self, xs: &Tensor) -> Result { + self.decode(&self.encode(xs)?) + } +} diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs new file mode 100644 index 0000000000..763fa90da1 --- /dev/null +++ b/candle-transformers/src/models/flux/mod.rs @@ -0,0 +1,3 @@ +pub mod autoencoder; +pub mod model; +pub mod sampling; diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs new file mode 100644 index 0000000000..aa00077e66 --- /dev/null +++ b/candle-transformers/src/models/flux/model.rs @@ -0,0 +1,582 @@ +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder}; + +// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12 +#[derive(Debug, Clone)] +pub struct Config { + pub in_channels: usize, + pub vec_in_dim: usize, + pub context_in_dim: usize, + pub hidden_size: usize, + pub mlp_ratio: f64, + pub num_heads: usize, + pub depth: usize, + pub depth_single_blocks: usize, + pub axes_dim: Vec, + pub theta: usize, + pub qkv_bias: bool, + pub guidance_embed: bool, +} + +impl Config { + // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32 + pub fn dev() -> Self { + Self { + in_channels: 64, + vec_in_dim: 768, + context_in_dim: 4096, + hidden_size: 3072, + mlp_ratio: 4.0, + num_heads: 24, + depth: 19, + depth_single_blocks: 38, + axes_dim: vec![16, 56, 56], + theta: 10_000, + qkv_bias: true, + guidance_embed: true, + } + } + + // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64 + pub fn schnell() -> Self { + Self { + in_channels: 64, + vec_in_dim: 768, + context_in_dim: 4096, + hidden_size: 3072, + mlp_ratio: 4.0, + num_heads: 24, + depth: 19, + depth_single_blocks: 38, + axes_dim: vec![16, 56, 56], + theta: 10_000, + qkv_bias: true, + guidance_embed: false, + } + } +} + +fn layer_norm(dim: usize, vb: VarBuilder) -> Result { + let ws = Tensor::ones(dim, vb.dtype(), vb.device())?; + Ok(LayerNorm::new_no_bias(ws, 1e-6)) +} + +fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result { + let dim = q.dim(D::Minus1)?; + let scale_factor = 1.0 / (dim as f64).sqrt(); + let mut batch_dims = q.dims().to_vec(); + batch_dims.pop(); + batch_dims.pop(); + let q = q.flatten_to(batch_dims.len() - 1)?; + let k = k.flatten_to(batch_dims.len() - 1)?; + let v = v.flatten_to(batch_dims.len() - 1)?; + let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?; + let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?; + batch_dims.push(attn_scores.dim(D::Minus2)?); + batch_dims.push(attn_scores.dim(D::Minus1)?); + attn_scores.reshape(batch_dims) +} + +fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result { + if dim % 2 == 1 { + candle::bail!("dim {dim} is odd") + } + let dev = pos.device(); + let theta = theta as f64; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?; + let inv_freq = inv_freq.to_dtype(pos.dtype())?; + let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?; + let cos = freqs.cos()?; + let sin = freqs.sin()?; + let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?; + let (b, n, d, _ij) = out.dims4()?; + out.reshape((b, n, d, 2, 2)) +} + +fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result { + let dims = x.dims(); + let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?; + let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?; + (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec()) +} + +fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result { + let q = apply_rope(q, pe)?.contiguous()?; + let k = apply_rope(k, pe)?.contiguous()?; + let x = scaled_dot_product_attention(&q, &k, v)?; + x.transpose(1, 2)?.flatten_from(2) +} + +fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result { + const TIME_FACTOR: f64 = 1000.; + const MAX_PERIOD: f64 = 10000.; + if dim % 2 == 1 { + candle::bail!("{dim} is odd") + } + let dev = t.device(); + let half = dim / 2; + let t = (t * TIME_FACTOR)?; + let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?; + let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?; + let args = t + .unsqueeze(1)? + .to_dtype(candle::DType::F32)? + .broadcast_mul(&freqs.unsqueeze(0)?)?; + let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?; + Ok(emb) +} + +#[derive(Debug, Clone)] +pub struct EmbedNd { + #[allow(unused)] + dim: usize, + theta: usize, + axes_dim: Vec, +} + +impl EmbedNd { + fn new(dim: usize, theta: usize, axes_dim: Vec) -> Self { + Self { + dim, + theta, + axes_dim, + } + } +} + +impl candle::Module for EmbedNd { + fn forward(&self, ids: &Tensor) -> Result { + let n_axes = ids.dim(D::Minus1)?; + let mut emb = Vec::with_capacity(n_axes); + for idx in 0..n_axes { + let r = rope( + &ids.get_on_dim(D::Minus1, idx)?, + self.axes_dim[idx], + self.theta, + )?; + emb.push(r) + } + let emb = Tensor::cat(&emb, 2)?; + emb.unsqueeze(1) + } +} + +#[derive(Debug, Clone)] +pub struct MlpEmbedder { + in_layer: Linear, + out_layer: Linear, +} + +impl MlpEmbedder { + fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result { + let in_layer = candle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?; + let out_layer = candle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?; + Ok(Self { + in_layer, + out_layer, + }) + } +} + +impl candle::Module for MlpEmbedder { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer) + } +} + +#[derive(Debug, Clone)] +pub struct QkNorm { + query_norm: RmsNorm, + key_norm: RmsNorm, +} + +impl QkNorm { + fn new(dim: usize, vb: VarBuilder) -> Result { + let query_norm = vb.get(dim, "query_norm.scale")?; + let query_norm = RmsNorm::new(query_norm, 1e-6); + let key_norm = vb.get(dim, "key_norm.scale")?; + let key_norm = RmsNorm::new(key_norm, 1e-6); + Ok(Self { + query_norm, + key_norm, + }) + } +} + +#[derive(Debug, Clone)] +pub struct Modulation { + lin: Linear, + multiplier: usize, +} + +impl Modulation { + fn new(dim: usize, double: bool, vb: VarBuilder) -> Result { + let multiplier = if double { 6 } else { 3 }; + let lin = candle_nn::linear(dim, multiplier * dim, vb.pp("lin"))?; + Ok(Self { lin, multiplier }) + } + + fn forward(&self, vec_: &Tensor) -> Result> { + vec_.silu()? + .apply(&self.lin)? + .unsqueeze(1)? + .chunk(self.multiplier, D::Minus1) + } +} + +#[derive(Debug, Clone)] +pub struct SelfAttention { + qkv: Linear, + norm: QkNorm, + proj: Linear, + num_heads: usize, +} + +impl SelfAttention { + fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result { + let head_dim = dim / num_heads; + let qkv = candle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?; + let norm = QkNorm::new(head_dim, vb.pp("norm"))?; + let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?; + Ok(Self { + qkv, + norm, + proj, + num_heads, + }) + } + + fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { + let qkv = xs.apply(&self.qkv)?; + let (b, l, _khd) = qkv.dims3()?; + let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?; + let q = qkv.i((.., .., 0))?.transpose(1, 2)?; + let k = qkv.i((.., .., 1))?.transpose(1, 2)?; + let v = qkv.i((.., .., 2))?.transpose(1, 2)?; + let q = q.apply(&self.norm.query_norm)?; + let k = k.apply(&self.norm.key_norm)?; + Ok((q, k, v)) + } + + #[allow(unused)] + fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result { + let (q, k, v) = self.qkv(xs)?; + attention(&q, &k, &v, pe)?.apply(&self.proj) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + lin1: Linear, + lin2: Linear, +} + +impl Mlp { + fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result { + let lin1 = candle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?; + let lin2 = candle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?; + Ok(Self { lin1, lin2 }) + } +} + +impl candle::Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2) + } +} + +#[derive(Debug, Clone)] +pub struct DoubleStreamBlock { + img_mod: Modulation, + img_norm1: LayerNorm, + img_attn: SelfAttention, + img_norm2: LayerNorm, + img_mlp: Mlp, + txt_mod: Modulation, + txt_norm1: LayerNorm, + txt_attn: SelfAttention, + txt_norm2: LayerNorm, + txt_mlp: Mlp, +} + +impl DoubleStreamBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h_sz = cfg.hidden_size; + let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize; + let img_mod = Modulation::new(h_sz, true, vb.pp("img_mod"))?; + let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?; + let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?; + let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?; + let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?; + let txt_mod = Modulation::new(h_sz, true, vb.pp("txt_mod"))?; + let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?; + let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?; + let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?; + let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?; + Ok(Self { + img_mod, + img_norm1, + img_attn, + img_norm2, + img_mlp, + txt_mod, + txt_norm1, + txt_attn, + txt_norm2, + txt_mlp, + }) + } + + fn forward( + &self, + img: &Tensor, + txt: &Tensor, + vec_: &Tensor, + pe: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let img_mod = self.img_mod.forward(vec_)?; // shift, scale, gate + let txt_mod = self.txt_mod.forward(vec_)?; // shift, scale, gate + let img_modulated = img.apply(&self.img_norm1)?; + let img_modulated = img_modulated + .broadcast_mul(&(&img_mod[1] + 1.)?)? + .broadcast_add(&img_mod[0])?; + let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?; + + let txt_modulated = txt.apply(&self.txt_norm1)?; + let txt_modulated = txt_modulated + .broadcast_mul(&(&txt_mod[1] + 1.)?)? + .broadcast_add(&txt_mod[0])?; + let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?; + + let q = Tensor::cat(&[txt_q, img_q], 2)?; + let k = Tensor::cat(&[txt_k, img_k], 2)?; + let v = Tensor::cat(&[txt_v, img_v], 2)?; + + let attn = attention(&q, &k, &v, pe)?; + let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?; + let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?; + + let img = (img + + img_attn + .apply(&self.img_attn.proj)? + .broadcast_mul(&img_mod[2]))?; + let img = (&img + + &img_mod[5].broadcast_mul( + &img.apply(&self.img_norm2)? + .broadcast_mul(&(&img_mod[4] + 1.0)?)? + .broadcast_add(&img_mod[3])? + .apply(&self.img_mlp)?, + )?)?; + + let txt = (txt + + txt_attn + .apply(&self.txt_attn.proj)? + .broadcast_mul(&txt_mod[2]))?; + let txt = (&txt + + &txt_mod[5].broadcast_mul( + &txt.apply(&self.txt_norm2)? + .broadcast_mul(&(&txt_mod[4] + 1.0)?)? + .broadcast_add(&txt_mod[3])? + .apply(&self.txt_mlp)?, + )?)?; + + Ok((img, txt)) + } +} + +#[derive(Debug, Clone)] +pub struct SingleStreamBlock { + linear1: Linear, + linear2: Linear, + norm: QkNorm, + pre_norm: LayerNorm, + modulation: Modulation, + h_sz: usize, + mlp_sz: usize, + num_heads: usize, +} + +impl SingleStreamBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h_sz = cfg.hidden_size; + let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize; + let head_dim = h_sz / cfg.num_heads; + let linear1 = candle_nn::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?; + let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?; + let norm = QkNorm::new(head_dim, vb.pp("norm"))?; + let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?; + let modulation = Modulation::new(h_sz, false, vb.pp("modulation"))?; + Ok(Self { + linear1, + linear2, + norm, + pre_norm, + modulation, + h_sz, + mlp_sz, + num_heads: cfg.num_heads, + }) + } + + fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result { + let mod_ = self.modulation.forward(vec_)?; + let (shift, scale, gate) = (&mod_[0], &mod_[1], &mod_[2]); + let x_mod = xs + .apply(&self.pre_norm)? + .broadcast_mul(&(scale + 1.0)?)? + .broadcast_add(shift)?; + let x_mod = x_mod.apply(&self.linear1)?; + let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?; + let (b, l, _khd) = qkv.dims3()?; + let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?; + let q = qkv.i((.., .., 0))?.transpose(1, 2)?; + let k = qkv.i((.., .., 1))?.transpose(1, 2)?; + let v = qkv.i((.., .., 2))?.transpose(1, 2)?; + let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?; + let q = q.apply(&self.norm.query_norm)?; + let k = k.apply(&self.norm.key_norm)?; + let attn = attention(&q, &k, &v, pe)?; + let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?; + xs + gate.broadcast_mul(&output) + } +} + +#[derive(Debug, Clone)] +pub struct LastLayer { + norm_final: LayerNorm, + linear: Linear, + ada_ln_modulation: Linear, +} + +impl LastLayer { + fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result { + let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?; + let linear = candle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?; + let ada_ln_modulation = candle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?; + Ok(Self { + norm_final, + linear, + ada_ln_modulation, + }) + } + + fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result { + let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?; + let (shift, scale) = (&chunks[0], &chunks[1]); + let xs = xs + .apply(&self.norm_final)? + .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)? + .broadcast_add(&shift.unsqueeze(1)?)?; + xs.apply(&self.linear) + } +} + +#[derive(Debug, Clone)] +pub struct Flux { + img_in: Linear, + txt_in: Linear, + time_in: MlpEmbedder, + vector_in: MlpEmbedder, + guidance_in: Option, + pe_embedder: EmbedNd, + double_blocks: Vec, + single_blocks: Vec, + final_layer: LastLayer, +} + +impl Flux { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let img_in = candle_nn::linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?; + let txt_in = candle_nn::linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?; + let mut double_blocks = Vec::with_capacity(cfg.depth); + let vb_d = vb.pp("double_blocks"); + for idx in 0..cfg.depth { + let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?; + double_blocks.push(db) + } + let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks); + let vb_s = vb.pp("single_blocks"); + for idx in 0..cfg.depth_single_blocks { + let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?; + single_blocks.push(sb) + } + let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?; + let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?; + let guidance_in = if cfg.guidance_embed { + let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?; + Some(mlp) + } else { + None + }; + let final_layer = + LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?; + let pe_dim = cfg.hidden_size / cfg.num_heads; + let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec()); + Ok(Self { + img_in, + txt_in, + time_in, + vector_in, + guidance_in, + pe_embedder, + double_blocks, + single_blocks, + final_layer, + }) + } + + #[allow(clippy::too_many_arguments)] + pub fn forward( + &self, + img: &Tensor, + img_ids: &Tensor, + txt: &Tensor, + txt_ids: &Tensor, + timesteps: &Tensor, + y: &Tensor, + guidance: Option<&Tensor>, + ) -> Result { + if txt.rank() != 3 { + candle::bail!("unexpected shape for txt {:?}", txt.shape()) + } + if img.rank() != 3 { + candle::bail!("unexpected shape for img {:?}", img.shape()) + } + let dtype = img.dtype(); + let pe = { + let ids = Tensor::cat(&[txt_ids, img_ids], 1)?; + ids.apply(&self.pe_embedder)? + }; + let mut txt = txt.apply(&self.txt_in)?; + let mut img = img.apply(&self.img_in)?; + let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?; + let vec_ = match (self.guidance_in.as_ref(), guidance) { + (Some(g_in), Some(guidance)) => { + (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))? + } + _ => vec_, + }; + let vec_ = (vec_ + y.apply(&self.vector_in))?; + + // Double blocks + for block in self.double_blocks.iter() { + (img, txt) = block.forward(&img, &txt, &vec_, &pe)? + } + // Single blocks + let mut img = Tensor::cat(&[&txt, &img], 1)?; + for block in self.single_blocks.iter() { + img = block.forward(&img, &vec_, &pe)?; + } + let img = img.i((.., txt.dim(1)?..))?; + self.final_layer.forward(&img, &vec_) + } +} diff --git a/candle-transformers/src/models/flux/sampling.rs b/candle-transformers/src/models/flux/sampling.rs new file mode 100644 index 0000000000..89b9a95382 --- /dev/null +++ b/candle-transformers/src/models/flux/sampling.rs @@ -0,0 +1,119 @@ +use candle::{Device, Result, Tensor}; + +pub fn get_noise( + num_samples: usize, + height: usize, + width: usize, + device: &Device, +) -> Result { + let height = (height + 15) / 16 * 2; + let width = (width + 15) / 16 * 2; + Tensor::randn(0f32, 1., (num_samples, 16, height, width), device) +} + +#[derive(Debug, Clone)] +pub struct State { + pub img: Tensor, + pub img_ids: Tensor, + pub txt: Tensor, + pub txt_ids: Tensor, + pub vec: Tensor, +} + +impl State { + pub fn new(t5_emb: &Tensor, clip_emb: &Tensor, img: &Tensor) -> Result { + let dtype = img.dtype(); + let (bs, c, h, w) = img.dims4()?; + let dev = img.device(); + let img = img.reshape((bs, c, h / 2, 2, w / 2, 2))?; // (b, c, h, ph, w, pw) + let img = img.permute((0, 2, 4, 1, 3, 5))?; // (b, h, w, c, ph, pw) + let img = img.reshape((bs, h / 2 * w / 2, c * 4))?; + let img_ids = Tensor::stack( + &[ + Tensor::full(0u32, (h / 2, w / 2), dev)?, + Tensor::arange(0u32, h as u32 / 2, dev)? + .reshape(((), 1))? + .broadcast_as((h / 2, w / 2))?, + Tensor::arange(0u32, w as u32 / 2, dev)? + .reshape((1, ()))? + .broadcast_as((h / 2, w / 2))?, + ], + 2, + )? + .to_dtype(dtype)?; + let img_ids = img_ids.reshape((1, h / 2 * w / 2, 3))?; + let img_ids = img_ids.repeat((bs, 1, 1))?; + let txt = t5_emb.repeat(bs)?; + let txt_ids = Tensor::zeros((bs, txt.dim(1)?, 3), dtype, dev)?; + let vec = clip_emb.repeat(bs)?; + Ok(Self { + img, + img_ids, + txt, + txt_ids, + vec, + }) + } +} + +fn time_shift(mu: f64, sigma: f64, t: f64) -> f64 { + let e = mu.exp(); + e / (e + (1. / t - 1.).powf(sigma)) +} + +/// `shift` is a triple `(image_seq_len, base_shift, max_shift)`. +pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec { + let timesteps: Vec = (0..=num_steps) + .map(|v| v as f64 / num_steps as f64) + .rev() + .collect(); + match shift { + None => timesteps, + Some((image_seq_len, y1, y2)) => { + let (x1, x2) = (256., 4096.); + let m = (y2 - y1) / (x2 - x1); + let b = y1 - m * x1; + let mu = m * image_seq_len as f64 + b; + timesteps + .into_iter() + .map(|v| time_shift(mu, 1., v)) + .collect() + } + } +} + +pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result { + let (b, _h_w, c_ph_pw) = xs.dims3()?; + let height = (height + 15) / 16; + let width = (width + 15) / 16; + xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw) + .permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw) + .reshape((b, c_ph_pw / 4, height * 2, width * 2)) +} + +#[allow(clippy::too_many_arguments)] +pub fn denoise( + model: &super::model::Flux, + img: &Tensor, + img_ids: &Tensor, + txt: &Tensor, + txt_ids: &Tensor, + vec_: &Tensor, + timesteps: &[f64], + guidance: f64, +) -> Result { + let b_sz = img.dim(0)?; + let dev = img.device(); + let guidance = Tensor::full(guidance as f32, b_sz, dev)?; + let mut img = img.clone(); + for window in timesteps.windows(2) { + let (t_curr, t_prev) = match window { + [a, b] => (a, b), + _ => continue, + }; + let t_vec = Tensor::full(*t_curr as f32, b_sz, dev)?; + let pred = model.forward(&img, img_ids, txt, txt_ids, &t_vec, vec_, Some(&guidance))?; + img = (img + pred * (t_prev - t_curr))? + } + Ok(img) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 836fdc7cce..fa35011994 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -17,6 +17,7 @@ pub mod efficientvit; pub mod encodec; pub mod eva2; pub mod falcon; +pub mod flux; pub mod gemma; pub mod hiera; pub mod jina_bert;