-
Notifications
You must be signed in to change notification settings - Fork 943
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the flux model for image generation. (#2390)
* Add the flux autoencoder. * Add the encoder down-blocks. * Upsampling in the decoder. * Sketch the flow matching model. * More flux model. * Add some of the positional embeddings. * Add the rope embeddings. * Add the sampling functions. * Add the flux example. * Fix the T5 bits. * Proper T5 tokenizer. * Clip encoder path fix. * Get the clip embeddings. * No configurable weights in layer norm. * More weights related fixes. * Yet another shape fix. * DType fix. * Fix a couple more shape issues. * DType fixes. * Fix the latent dims. * Fix more shape issues. * Autoencoder fixes. * Get some generations out. * Bugfix. * T5 padding. * Clippy fix. * Add the decode only mode. * Fix. * More fixes. * Finally get some generations to work. * Add readme.
- Loading branch information
1 parent
0fcb40b
commit 19db6b9
Showing
8 changed files
with
1,346 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# candle-flux: image generation with latent rectified flow transformers | ||
|
||
![rusty robot holding a candle](./assets/flux-robot.jpg) | ||
|
||
Flux is a 12B rectified flow transformer capable of generating images from text | ||
descriptions, | ||
[huggingface](https://huggingface.co/black-forest-labs/FLUX.1-schnell), | ||
[github](https://github.com/black-forest-labs/flux), | ||
[blog post](https://blackforestlabs.ai/announcing-black-forest-labs/). | ||
|
||
|
||
## Running the model | ||
|
||
```bash | ||
cargo run --features cuda --example flux -r -- \ | ||
--height 1024 --width 1024 | ||
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k" | ||
``` | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
#[cfg(feature = "accelerate")] | ||
extern crate accelerate_src; | ||
|
||
#[cfg(feature = "mkl")] | ||
extern crate intel_mkl_src; | ||
|
||
use candle_transformers::models::{clip, flux, t5}; | ||
|
||
use anyhow::{Error as E, Result}; | ||
use candle::{IndexOp, Module, Tensor}; | ||
use candle_nn::VarBuilder; | ||
use clap::Parser; | ||
use tokenizers::Tokenizer; | ||
|
||
#[derive(Parser)] | ||
#[command(author, version, about, long_about = None)] | ||
struct Args { | ||
/// The prompt to be used for image generation. | ||
#[arg(long, default_value = "A rusty robot walking on a beach")] | ||
prompt: String, | ||
|
||
/// Run on CPU rather than on GPU. | ||
#[arg(long)] | ||
cpu: bool, | ||
|
||
/// Enable tracing (generates a trace-timestamp.json file). | ||
#[arg(long)] | ||
tracing: bool, | ||
|
||
/// The height in pixels of the generated image. | ||
#[arg(long)] | ||
height: Option<usize>, | ||
|
||
/// The width in pixels of the generated image. | ||
#[arg(long)] | ||
width: Option<usize>, | ||
|
||
#[arg(long)] | ||
decode_only: Option<String>, | ||
} | ||
|
||
fn run(args: Args) -> Result<()> { | ||
use tracing_chrome::ChromeLayerBuilder; | ||
use tracing_subscriber::prelude::*; | ||
|
||
let Args { | ||
prompt, | ||
cpu, | ||
height, | ||
width, | ||
tracing, | ||
decode_only, | ||
} = args; | ||
let width = width.unwrap_or(1360); | ||
let height = height.unwrap_or(768); | ||
|
||
let _guard = if tracing { | ||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); | ||
tracing_subscriber::registry().with(chrome_layer).init(); | ||
Some(guard) | ||
} else { | ||
None | ||
}; | ||
|
||
let api = hf_hub::api::sync::Api::new()?; | ||
let bf_repo = api.repo(hf_hub::Repo::model( | ||
"black-forest-labs/FLUX.1-schnell".to_string(), | ||
)); | ||
let device = candle_examples::device(cpu)?; | ||
let dtype = device.bf16_default_to_f32(); | ||
let img = match decode_only { | ||
None => { | ||
let t5_emb = { | ||
let repo = api.repo(hf_hub::Repo::with_revision( | ||
"google/t5-v1_1-xxl".to_string(), | ||
hf_hub::RepoType::Model, | ||
"refs/pr/2".to_string(), | ||
)); | ||
let model_file = repo.get("model.safetensors")?; | ||
let vb = | ||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; | ||
let config_filename = repo.get("config.json")?; | ||
let config = std::fs::read_to_string(config_filename)?; | ||
let config: t5::Config = serde_json::from_str(&config)?; | ||
let mut model = t5::T5EncoderModel::load(vb, &config)?; | ||
let tokenizer_filename = api | ||
.model("lmz/mt5-tokenizers".to_string()) | ||
.get("t5-v1_1-xxl.tokenizer.json")?; | ||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; | ||
let mut tokens = tokenizer | ||
.encode(prompt.as_str(), true) | ||
.map_err(E::msg)? | ||
.get_ids() | ||
.to_vec(); | ||
tokens.resize(256, 0); | ||
let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; | ||
println!("{input_token_ids}"); | ||
model.forward(&input_token_ids)? | ||
}; | ||
println!("T5\n{t5_emb}"); | ||
let clip_emb = { | ||
let repo = api.repo(hf_hub::Repo::model( | ||
"openai/clip-vit-large-patch14".to_string(), | ||
)); | ||
let model_file = repo.get("model.safetensors")?; | ||
let vb = | ||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; | ||
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json | ||
let config = clip::text_model::ClipTextConfig { | ||
vocab_size: 49408, | ||
projection_dim: 768, | ||
activation: clip::text_model::Activation::QuickGelu, | ||
intermediate_size: 3072, | ||
embed_dim: 768, | ||
max_position_embeddings: 77, | ||
pad_with: None, | ||
num_hidden_layers: 12, | ||
num_attention_heads: 12, | ||
}; | ||
let model = | ||
clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config)?; | ||
let tokenizer_filename = repo.get("tokenizer.json")?; | ||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; | ||
let tokens = tokenizer | ||
.encode(prompt.as_str(), true) | ||
.map_err(E::msg)? | ||
.get_ids() | ||
.to_vec(); | ||
let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; | ||
println!("{input_token_ids}"); | ||
model.forward(&input_token_ids)? | ||
}; | ||
println!("CLIP\n{clip_emb}"); | ||
let img = { | ||
let model_file = bf_repo.get("flux1-schnell.sft")?; | ||
let vb = | ||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; | ||
let cfg = flux::model::Config::schnell(); | ||
let model = flux::model::Flux::new(&cfg, vb)?; | ||
|
||
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?; | ||
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?; | ||
println!("{state:?}"); | ||
let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell | ||
println!("{timesteps:?}"); | ||
flux::sampling::denoise( | ||
&model, | ||
&state.img, | ||
&state.img_ids, | ||
&state.txt, | ||
&state.txt_ids, | ||
&state.vec, | ||
×teps, | ||
4., | ||
)? | ||
}; | ||
flux::sampling::unpack(&img, height, width)? | ||
} | ||
Some(file) => { | ||
let mut st = candle::safetensors::load(file, &device)?; | ||
st.remove("img").unwrap().to_dtype(dtype)? | ||
} | ||
}; | ||
println!("latent img\n{img}"); | ||
|
||
let img = { | ||
let model_file = bf_repo.get("ae.sft")?; | ||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; | ||
let cfg = flux::autoencoder::Config::schnell(); | ||
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?; | ||
model.decode(&img)? | ||
}; | ||
println!("img\n{img}"); | ||
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?; | ||
candle_examples::save_image(&img.i(0)?, "out.jpg")?; | ||
Ok(()) | ||
} | ||
|
||
fn main() -> Result<()> { | ||
let args = Args::parse(); | ||
run(args) | ||
} |
Oops, something went wrong.