From 0ec5ebcec429fe2bb85a6a7f780509bb1831b024 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 26 Oct 2023 20:00:50 +0100 Subject: [PATCH 01/22] Use the hub model file when possible. (#1190) * Use the hub model file when possible. * And add a mention in the main readme. --- README.md | 5 ++- candle-examples/examples/jina-bert/README.md | 45 ++++++++++++++++++++ candle-examples/examples/jina-bert/main.rs | 28 +++++++++--- 3 files changed, 71 insertions(+), 7 deletions(-) create mode 100644 candle-examples/examples/jina-bert/README.md diff --git a/README.md b/README.md index eb7c189b82..8c076ec708 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ These online demos run entirely in your browser: - [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation. - [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation. - [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation. -- [Blip](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning. +- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning. We also provide a some command line based examples using state of the art models: @@ -96,7 +96,8 @@ We also provide a some command line based examples using state of the art models - [Whisper](./candle-examples/examples/whisper/): speech recognition model. -- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings. +- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/), + [JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings. - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained using self-supervision (can be used for imagenet classification, depth evaluation, segmentation). diff --git a/candle-examples/examples/jina-bert/README.md b/candle-examples/examples/jina-bert/README.md new file mode 100644 index 0000000000..02afbaa98f --- /dev/null +++ b/candle-examples/examples/jina-bert/README.md @@ -0,0 +1,45 @@ +# candle-jina-bert + +Jina-Bert is a general large language model with a context size of 8192, [model +card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example +it can be used for two different tasks: +- Compute sentence embeddings for a prompt. +- Compute similarities between a set of sentences. + + +## Sentence embeddings + +Jina-Bert is used to compute the sentence embeddings for a prompt. The model weights +are downloaded from the hub on the first run. + +```bash +cargo run --example jina-bert --release -- --prompt "Here is a test sentence" + +> [[[ 0.1595, -0.9885, 0.6494, ..., 0.3003, -0.6901, -1.2355], +> [ 0.0374, -0.1798, 1.3359, ..., 0.6731, 0.2133, -1.6807], +> [ 0.1700, -0.8534, 0.8924, ..., -0.1785, -0.0727, -1.5087], +> ... +> [-0.3113, -1.3665, 0.2027, ..., -0.2519, 0.1711, -1.5811], +> [ 0.0907, -1.0492, 0.5382, ..., 0.0242, -0.7077, -1.0830], +> [ 0.0369, -0.6343, 0.6105, ..., 0.0671, 0.3778, -1.1505]]] +> Tensor[[1, 7, 768], f32] +``` + +## Similarities + +In this example, Jina-Bert is used to compute the sentence embeddings for a set of +sentences (hardcoded in the examples). Then cosine similarities are computed for +each sentence pair and they are reported by decreasing values, hence the first +reported pair contains the two sentences that have the highest similarity score. 
+The sentence embeddings are computed using average pooling through all the +sentence tokens, including some potential padding. + +```bash +cargo run --example jina-bert --release + +> score: 0.94 'The new movie is awesome' 'The new movie is so great' +> score: 0.81 'The cat sits outside' 'The cat plays in the garden' +> score: 0.78 'I love pasta' 'Do you like pizza?' +> score: 0.68 'I love pasta' 'The new movie is awesome' +> score: 0.67 'A man is playing guitar' 'A woman watches TV' +``` diff --git a/candle-examples/examples/jina-bert/main.rs b/candle-examples/examples/jina-bert/main.rs index ffde777d9a..d959d4cb19 100644 --- a/candle-examples/examples/jina-bert/main.rs +++ b/candle-examples/examples/jina-bert/main.rs @@ -35,19 +35,37 @@ struct Args { normalize_embeddings: bool, #[arg(long)] - tokenizer: String, + tokenizer: Option, #[arg(long)] - model: String, + model: Option, } impl Args { fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> { + use hf_hub::{api::sync::Api, Repo, RepoType}; + let model = match &self.model { + Some(model_file) => std::path::PathBuf::from(model_file), + None => Api::new()? + .repo(Repo::new( + "jinaai/jina-embeddings-v2-base-en".to_string(), + RepoType::Model, + )) + .get("model.safetensors")?, + }; + let tokenizer = match &self.tokenizer { + Some(file) => std::path::PathBuf::from(file), + None => Api::new()? + .repo(Repo::new( + "sentence-transformers/all-MiniLM-L6-v2".to_string(), + RepoType::Model, + )) + .get("tokenizer.json")?, + }; let device = candle_examples::device(self.cpu)?; let config = Config::v2_base(); - let tokenizer = tokenizers::Tokenizer::from_file(&self.tokenizer).map_err(E::msg)?; - let vb = - unsafe { VarBuilder::from_mmaped_safetensors(&[&self.model], DType::F32, &device)? }; + let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }; let model = BertModel::new(vb, &config)?; Ok((model, tokenizer)) } From 70d06ab4b0065576e779a628fc024ef46003cdbc Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 27 Oct 2023 05:57:08 +0100 Subject: [PATCH 02/22] Add support for the phi-hermes finetuned model. 
(#1192) --- candle-examples/examples/phi/main.rs | 14 +++++++++++--- candle-transformers/src/models/mixformer.rs | 17 +++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 9401299a9c..720a4441eb 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -124,6 +124,7 @@ enum WhichModel { #[value(name = "1.5")] V1_5, PuffinPhiV2, + PhiHermes, } #[derive(Parser, Debug)] @@ -224,7 +225,9 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "microsoft/phi-1".to_string(), WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), - WhichModel::PuffinPhiV2 => "lmz/candle-quantized-phi".to_string(), + WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { + "lmz/candle-quantized-phi".to_string() + } } } } @@ -238,7 +241,7 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "refs/pr/2".to_string(), WhichModel::V1_5 => "refs/pr/18".to_string(), - WhichModel::PuffinPhiV2 => "main".to_string(), + WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(), } } } @@ -248,7 +251,9 @@ fn main() -> Result<()> { Some(file) => std::path::PathBuf::from(file), None => match args.model { WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?, - WhichModel::PuffinPhiV2 => repo.get("tokenizer-puffin-phi-v2.json")?, + WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { + repo.get("tokenizer-puffin-phi-v2.json")? + } }, }; let filename = match args.weight_file { @@ -259,11 +264,13 @@ fn main() -> Result<()> { WhichModel::V1 => repo.get("model-v1-q4k.gguf")?, WhichModel::V1_5 => repo.get("model-q4k.gguf")?, WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?, + WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?, } } else { match args.model { WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?, WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?, + WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?, } } } @@ -276,6 +283,7 @@ fn main() -> Result<()> { WhichModel::V1 => Config::v1(), WhichModel::V1_5 => Config::v1_5(), WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), + WhichModel::PhiHermes => Config::phi_hermes_1_3b(), }; let (model, device) = if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?; diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index 33aefbfe1a..e822ca1464 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -73,6 +73,23 @@ impl Config { pad_vocab_size_multiple: 64, } } + + // https://huggingface.co/teknium/Phi-Hermes-1.3B/blob/main/config.json + pub fn phi_hermes_1_3b() -> Self { + Self { + vocab_size: 50304, + n_positions: 2048, + n_embd: 2048, + n_layer: 24, + n_inner: None, + n_head: 32, + rotary_dim: usize::min(32, 2048 / 32), + activation_function: Activation::NewGelu, + layer_norm_epsilon: 1e-5, + tie_word_embeddings: false, + pad_vocab_size_multiple: 64, + } + } } #[derive(Debug, Clone)] From 9b1158b3158dae2eafb91e9da126f66bf9e111d6 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 27 Oct 2023 06:09:11 +0100 Subject: [PATCH 03/22] Add some missing backtraces. 
(#1193) --- candle-core/src/tensor.rs | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 9dea62faa4..ce81d8aff0 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1186,14 +1186,16 @@ impl Tensor { op: "scatter-add (self, src)", lhs: self.shape().clone(), rhs: source.shape().clone(), - })? + } + .bt())? } if indexes.dims() != source.dims() { Err(Error::ShapeMismatchBinaryOp { op: "scatter-add (indexes, src)", lhs: indexes.shape().clone(), rhs: source.shape().clone(), - })? + } + .bt())? } let storage = self.storage().scatter_add( self.layout(), @@ -1265,7 +1267,8 @@ impl Tensor { op: "slice-scatter (self, src)", lhs: self.shape().clone(), rhs: src.shape().clone(), - })? + } + .bt())? } let mut storage = self.device().zeros(self.shape(), self.dtype())?; self.storage() @@ -1299,7 +1302,8 @@ impl Tensor { op: "index-add (self, source)", lhs: self.shape().clone(), rhs: source.shape().clone(), - })? + } + .bt())? } // The number of element in indexes must match the dimension on which the add is // performed on the source tensor (and the index values from `indexes` are taken from @@ -1310,7 +1314,8 @@ impl Tensor { op: "index-add (ids, source))", lhs: indexes.shape().clone(), rhs: source.shape().clone(), - })? + } + .bt())? } let storage = self.storage().index_add( self.layout(), @@ -1358,7 +1363,8 @@ impl Tensor { op: "gather", lhs: self.shape().clone(), rhs: indexes.shape().clone(), - })? + } + .bt())? } let storage = self.storage() From 916619f70bfae089597ce421e19a3b2e85c2d27b Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 27 Oct 2023 14:08:29 +0100 Subject: [PATCH 04/22] Minor cleanup (#1194) * Add some missing backtraces. * Small cleanup. --- candle-wasm-examples/llama2-c/src/model.rs | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs index 3fedb1d365..7471938af1 100644 --- a/candle-wasm-examples/llama2-c/src/model.rs +++ b/candle-wasm-examples/llama2-c/src/model.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; +use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -57,20 +57,6 @@ impl Cache { } } -fn silu(xs: &Tensor) -> Result { - xs / (xs.neg()?.exp()? + 1.0)? -} - -fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { - let weight = vb.get((size2, size1), "weight")?; - Ok(Linear::new(weight, None)) -} - -fn embedding(cfg: &Config, vb: VarBuilder) -> Result { - let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?; - Ok(Embedding::new(embeddings, cfg.dim)) -} - struct CausalSelfAttention { q_proj: Linear, k_proj: Linear, @@ -198,7 +184,7 @@ impl Mlp { } fn forward(&self, x: &Tensor) -> Result { - let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? 
* self.c_fc2.forward(x)?)?; self.c_proj.forward(&x) } @@ -283,7 +269,7 @@ impl Llama { } pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result { - let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; + let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?; let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.n_layers) From e2826e70b3725c53656f1ff76753472b29e1c5f7 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 27 Oct 2023 15:34:06 +0100 Subject: [PATCH 05/22] Add a quantized variant of llama2.c (#1197) * Add a quantized variant of llama2.c * Clippy fixes. --- candle-core/src/quantized/neon.rs | 26 +-- candle-core/src/quantized/simd128.rs | 4 - candle-examples/examples/llama2-c/main.rs | 60 +++++- candle-examples/examples/llama2-c/model.rs | 8 +- candle-examples/examples/llama2-c/qmodel.rs | 227 ++++++++++++++++++++ 5 files changed, 287 insertions(+), 38 deletions(-) create mode 100644 candle-examples/examples/llama2-c/qmodel.rs diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index fd4c138818..51bd3e66ef 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -94,28 +94,18 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") } let nb = n / QK8_0; - if nb % 2 != 0 { - crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even") - } unsafe { let mut sumv0 = vdupq_n_f32(0.0f32); - let mut sumv1 = vdupq_n_f32(0.0f32); - for i in (0..nb).step_by(2) { + for i in 0..nb { let x0 = &xs[i]; - let x1 = &xs[i + 1]; let y0 = &ys[i]; - let y1 = &ys[i + 1]; let x0_0 = vld1q_s8(x0.qs.as_ptr()); let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16)); - let x1_0 = vld1q_s8(x1.qs.as_ptr()); - let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16)); // load y let y0_0 = vld1q_s8(y0.qs.as_ptr()); let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); - let y1_0 = vld1q_s8(y1.qs.as_ptr()); - let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16)); // TODO dotprod once this is the intrinsics are. 
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0)); @@ -123,28 +113,16 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1)); let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0)); - let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); - let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1)); - let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); - let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); - let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); - let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0.d.to_f32() * y0.d.to_f32(), ); - sumv1 = vmlaq_n_f32( - sumv1, - vcvtq_f32_s32(vaddq_s32(p2, p3)), - x1.d.to_f32() * y1.d.to_f32(), - ); } - Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1)) + Ok(vaddvq_f32(sumv0)) } } diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs index 687399c201..f256fdc204 100644 --- a/candle-core/src/quantized/simd128.rs +++ b/candle-core/src/quantized/simd128.rs @@ -61,10 +61,6 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> if n % QK8_0 != 0 { crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") } - let nb = n / QK8_0; - if nb % 2 != 0 { - crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even") - } unsafe { let mut acc = f32x4_splat(0.0f32); for (x, y) in xs.iter().zip(ys.iter()) { diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index e752a494ac..77dbc6778e 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -7,6 +7,7 @@ extern crate accelerate_src; extern crate intel_mkl_src; mod model; +mod qmodel; mod training; mod weights; use clap::{Parser, Subcommand}; @@ -19,6 +20,7 @@ use std::io::Write; use tokenizers::Tokenizer; use model::{Config, Llama}; +use qmodel::QLlama; use weights::TransformerWeights; #[derive(Parser, Debug, Clone)] @@ -152,6 +154,20 @@ fn main() -> anyhow::Result<()> { Ok(()) } +enum Model { + Llama(Llama), + QLlama(QLlama), +} + +impl Model { + fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result { + match self { + Self::Llama(l) => Ok(l.forward(xs, pos)?), + Self::QLlama(l) => Ok(l.forward(xs, pos)?), + } + } +} + fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> { use std::io::BufRead; @@ -241,24 +257,56 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let device = candle_examples::device(common_args.cpu)?; + let is_gguf = config_path.extension().map_or(false, |v| v == "gguf"); let is_safetensors = config_path .extension() .map_or(false, |v| v == "safetensors"); - let (vb, config) = if is_safetensors { + let (model, config) = if is_gguf { + let config = Config::tiny(); + let vb = qmodel::VarBuilder::from_gguf(config_path)?; + let freq_cis_real = vb + .get( + (config.seq_len, config.head_size() / 2), + "rot.freq_cis_real", + )? + .dequantize(&candle::Device::Cpu)?; + let freq_cis_imag = vb + .get( + (config.seq_len, config.head_size() / 2), + "rot.freq_cis_imag", + )? 
+ .dequantize(&candle::Device::Cpu)?; + + let fake_vb = candle_nn::VarBuilder::from_tensors( + [ + ("freq_cis_real".to_string(), freq_cis_real), + ("freq_cis_imag".to_string(), freq_cis_imag), + ] + .into_iter() + .collect(), + candle::DType::F32, + &candle::Device::Cpu, + ); + let cache = model::Cache::new(true, &config, fake_vb)?; + let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?); + (model, config) + } else if is_safetensors { let config = Config::tiny(); let tensors = candle::safetensors::load(config_path, &device)?; let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); - (vb, config) + let cache = model::Cache::new(true, &config, vb.pp("rot"))?; + let model = Model::Llama(Llama::load(vb, &cache, config.clone())?); + (model, config) } else { let mut file = std::fs::File::open(config_path)?; let config = Config::from_reader(&mut file)?; println!("{config:?}"); let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let vb = weights.var_builder(&config, &device)?; - (vb, config) + let cache = model::Cache::new(true, &config, vb.pp("rot"))?; + let model = Model::Llama(Llama::load(vb, &cache, config.clone())?); + (model, config) }; - let cache = model::Cache::new(true, &config, vb.pp("rot"))?; - let model = Llama::load(vb, &cache, config)?; println!("starting the inference loop"); let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p); @@ -273,7 +321,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let start_gen = std::time::Instant::now(); for index in 0.. { - if tokens.len() >= model.config.seq_len { + if tokens.len() >= config.seq_len { break; } let context_size = if index > 0 { 1 } else { tokens.len() }; diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs index 9b982dddb7..07a6e2f211 100644 --- a/candle-examples/examples/llama2-c/model.rs +++ b/candle-examples/examples/llama2-c/model.rs @@ -36,9 +36,9 @@ pub struct Cache { masks: Arc>>, pub use_kv_cache: bool, #[allow(clippy::type_complexity)] - kvs: Arc>>>, - cos: Tensor, - sin: Tensor, + pub kvs: Arc>>>, + pub cos: Tensor, + pub sin: Tensor, device: Device, } @@ -75,7 +75,7 @@ impl Cache { }) } - fn mask(&self, t: usize) -> Result { + pub fn mask(&self, t: usize) -> Result { let mut masks = self.masks.lock().unwrap(); if let Some(mask) = masks.get(&t) { Ok(mask.clone()) diff --git a/candle-examples/examples/llama2-c/qmodel.rs b/candle-examples/examples/llama2-c/qmodel.rs new file mode 100644 index 0000000000..07db146ebd --- /dev/null +++ b/candle-examples/examples/llama2-c/qmodel.rs @@ -0,0 +1,227 @@ +use super::model::{Cache, Config}; +use candle::{DType, IndexOp, Module, Result, Tensor, D}; +use candle_transformers::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm}; +pub use candle_transformers::quantized_var_builder::VarBuilder; + +fn silu(xs: &Tensor) -> Result { + xs / (xs.neg()?.exp()? + 1.0)? 
+} + +struct CausalSelfAttention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + n_head: usize, + n_key_value_head: usize, + head_dim: usize, + cache: Cache, +} + +impl CausalSelfAttention { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result { + let (b_sz, seq_len, h, n_embd) = x.dims4()?; + let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?; + let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?; + let cos = cos.unsqueeze(1)?; + let sin = sin.unsqueeze(1)?; + let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; + let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; + let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?; + Ok(rope) + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { + let (b_sz, seq_len, n_embd) = x.dims3()?; + let q = self.q_proj.forward(x)?; + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + + let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?; + let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; + let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let mut k = self.apply_rotary_emb(&k, index_pos)?; + + if self.cache.use_kv_cache { + let mut cache = self.cache.kvs.lock().unwrap(); + if let Some((cache_k, cache_v)) = &cache[block_idx] { + k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?; + v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?; + } + cache[block_idx] = Some((k.clone(), v.clone())) + } + + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + + let q = q.transpose(1, 2)?.contiguous()?; + let k = k.transpose(1, 2)?.contiguous()?; + let v = v.transpose(1, 2)?.contiguous()?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = candle_nn::ops::softmax(&att, D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.o_proj.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result { + let n_rep = self.n_head / self.n_key_value_head; + if n_rep == 1 { + Ok(x) + } else { + let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?; + let x = x + .unsqueeze(3)? + .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))? 
+ .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?; + Ok(x) + } + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result { + let size_in = cfg.dim; + let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads; + let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads; + let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; + let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; + let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; + let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + n_head: cfg.n_heads, + n_key_value_head: cfg.n_kv_heads, + head_dim: cfg.dim / cfg.n_heads, + cache: cache.clone(), + }) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +struct Mlp { + c_fc1: Linear, + c_fc2: Linear, + c_proj: Linear, +} + +impl Mlp { + fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self { + Self { + c_fc1, + c_fc2, + c_proj, + } + } + + fn forward(&self, x: &Tensor) -> Result { + let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + self.c_proj.forward(&x) + } + + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let h_size = cfg.dim; + let i_size = cfg.hidden_dim; + let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?; + let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?; + let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?; + Ok(Self::new(c_fc1, c_fc2, c_proj)) + } +} + +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, +} + +impl Block { + fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { + Self { + rms_1, + attn, + rms_2, + mlp, + } + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? 
+ residual)?; + Ok(x) + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result { + let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?; + let mlp = Mlp::load(vb.pp("mlp"), cfg)?; + let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = + RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?; + Ok(Self::new( + input_layernorm, + attn, + post_attention_layernorm, + mlp, + )) + } +} + +pub struct QLlama { + wte: Embedding, + blocks: Vec, + ln_f: RmsNorm, + lm_head: Linear, + pub config: Config, +} + +impl QLlama { + pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result { + let (_b_sz, _seq_len) = x.dims2()?; + let mut x = self.wte.forward(x)?; + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, index_pos, block_idx)?; + } + let x = self.ln_f.forward(&x)?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + + pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result { + let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?; + let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; + let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; + let blocks: Vec<_> = (0..cfg.n_layers) + .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap()) + .collect(); + Ok(Self { + wte, + blocks, + ln_f, + lm_head, + config: cfg, + }) + } +} From b3181455d5bbebdcc358a48fd4d1e5ed38d78198 Mon Sep 17 00:00:00 2001 From: jamjamjon <51357717+jamjamjon@users.noreply.github.com> Date: Fri, 27 Oct 2023 22:56:50 +0800 Subject: [PATCH 06/22] Add fuse-conv-bn method for Conv2d (#1196) * Add fuse-conv-bn method for Conv2d * no unwrap * run rustfmp and clippy --- candle-examples/examples/yolo-v8/model.rs | 9 ++------- candle-nn/src/batch_norm.rs | 4 ++++ candle-nn/src/conv.rs | 21 +++++++++++++++++++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs index bf48fd8419..cecd4ce6c4 100644 --- a/candle-examples/examples/yolo-v8/model.rs +++ b/candle-examples/examples/yolo-v8/model.rs @@ -1,7 +1,5 @@ use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{ - batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder, -}; +use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder}; #[derive(Clone, Copy, PartialEq, Debug)] pub struct Multiples { @@ -76,7 +74,6 @@ impl Module for Upsample { #[derive(Debug)] struct ConvBlock { conv: Conv2d, - bn: BatchNorm, span: tracing::Span, } @@ -96,11 +93,10 @@ impl ConvBlock { groups: 1, dilation: 1, }; - let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?; let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?; + let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?; Ok(Self { conv, - bn, span: tracing::span!(tracing::Level::TRACE, "conv-block"), }) } @@ -110,7 +106,6 @@ impl Module for ConvBlock { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); let xs = self.conv.forward(xs)?; - let xs = self.bn.forward(&xs)?; candle_nn::ops::silu(&xs) } } diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs index 05904859a7..8cfc6740b4 100644 --- a/candle-nn/src/batch_norm.rs +++ b/candle-nn/src/batch_norm.rs @@ -109,6 +109,10 @@ impl BatchNorm { &self.running_var } + pub fn eps(&self) -> f64 { + self.eps + } + pub fn 
weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> { self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1)) } diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 89e9f42d4c..7c0bf841dc 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -1,4 +1,5 @@ //! Convolution Layers. +use crate::BatchNorm; use candle::{Result, Tensor}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -115,6 +116,26 @@ impl Conv2d { pub fn bias(&self) -> Option<&Tensor> { self.bias.as_ref() } + + pub fn absorb_bn(&self, bn: &BatchNorm) -> Result { + if let Some((w_bn, b_bn)) = bn.weight_and_bias() { + let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?; + let weight = self + .weight() + .broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?; + let bias = match &self.bias { + None => b_bn.sub(&(std_.mul(bn.running_mean())?))?, + Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?, + }; + Ok(Self { + weight, + bias: Some(bias), + config: self.config, + }) + } else { + candle::bail!("batch norm does not have weight_and_bias") + } + } } impl crate::Module for Conv2d { From 85bea43e5b088b94612b0fd7ed8f09261dc79d52 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 27 Oct 2023 17:59:19 +0200 Subject: [PATCH 07/22] Make the whisper model cloneable (#1200) * Add a quantized variant of llama2.c * Clippy fixes. * Make the whisper model cloneable. --- candle-transformers/src/models/whisper/model.rs | 7 ++++++- candle-transformers/src/models/whisper/quantized_model.rs | 5 +++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs index 2a58afafe9..6078944c1a 100644 --- a/candle-transformers/src/models/whisper/model.rs +++ b/candle-transformers/src/models/whisper/model.rs @@ -9,7 +9,7 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result Result { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +#[derive(Debug, Clone)] struct MultiHeadAttention { query: Linear, key: Linear, @@ -162,6 +163,7 @@ impl MultiHeadAttention { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +#[derive(Debug, Clone)] struct ResidualAttentionBlock { attn: MultiHeadAttention, attn_ln: LayerNorm, @@ -241,6 +243,7 @@ fn sinusoids(length: usize, channels: usize) -> Result { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +#[derive(Debug, Clone)] pub struct AudioEncoder { conv1: Conv1d, conv2: Conv1d, @@ -316,6 +319,7 @@ impl AudioEncoder { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +#[derive(Debug, Clone)] pub struct TextDecoder { token_embedding: Embedding, positional_embedding: Tensor, @@ -380,6 +384,7 @@ impl TextDecoder { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +#[derive(Debug, Clone)] pub struct Whisper { pub encoder: AudioEncoder, pub decoder: TextDecoder, diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs index f0aead49aa..43ea4177d5 100644 --- a/candle-transformers/src/models/whisper/quantized_model.rs +++ b/candle-transformers/src/models/whisper/quantized_model.rs @@ -19,6 +19,7 @@ fn conv1d( } // 
https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +#[derive(Debug, Clone)] struct MultiHeadAttention { query: Linear, key: Linear, @@ -128,6 +129,7 @@ impl MultiHeadAttention { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +#[derive(Debug, Clone)] struct ResidualAttentionBlock { attn: MultiHeadAttention, attn_ln: LayerNorm, @@ -206,6 +208,7 @@ fn sinusoids(length: usize, channels: usize) -> Result { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +#[derive(Debug, Clone)] pub struct AudioEncoder { conv1: Conv1d, conv2: Conv1d, @@ -281,6 +284,7 @@ impl AudioEncoder { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +#[derive(Debug, Clone)] pub struct TextDecoder { token_embedding: Embedding, positional_embedding: Tensor, @@ -347,6 +351,7 @@ impl TextDecoder { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +#[derive(Debug, Clone)] pub struct Whisper { pub encoder: AudioEncoder, pub decoder: TextDecoder, From c8face3f95a9c57b4714cd95dc69237533558c25 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 27 Oct 2023 21:51:16 +0200 Subject: [PATCH 08/22] Add the relu2 and relu6 activations. (#1201) --- candle-nn/src/activation.rs | 4 ++ candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/persimmon.rs | 56 +++++++++++++++++++++ 3 files changed, 61 insertions(+) create mode 100644 candle-transformers/src/models/persimmon.rs diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs index ddc211a732..52ceba78c7 100644 --- a/candle-nn/src/activation.rs +++ b/candle-nn/src/activation.rs @@ -9,6 +9,8 @@ pub enum Activation { #[serde(rename = "gated-gelu")] NewGelu, Relu, + Relu2, + Relu6, Silu, Sigmoid, Elu(f64), @@ -22,6 +24,8 @@ impl super::Module for Activation { // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78 Self::NewGelu => xs.gelu(), Self::Relu => xs.relu(), + Self::Relu2 => xs.relu()?.sqr(), + Self::Relu6 => xs.clamp(0f32, 6f32), Self::Silu => crate::ops::silu(xs), Self::Sigmoid => crate::ops::sigmoid(xs), &Self::Elu(alpha) => xs.elu(alpha), diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 4e7c8bf002..f722e93b04 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -11,6 +11,7 @@ pub mod llama; pub mod mistral; pub mod mixformer; pub mod mpt; +pub mod persimmon; pub mod quantized_blip; pub mod quantized_blip_text; pub mod quantized_llama; diff --git a/candle-transformers/src/models/persimmon.rs b/candle-transformers/src/models/persimmon.rs new file mode 100644 index 0000000000..afee7c83ee --- /dev/null +++ b/candle-transformers/src/models/persimmon.rs @@ -0,0 +1,56 @@ +use candle::DType; +use serde::Deserialize; + +pub const DTYPE: DType = DType::F32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum PositionEmbeddingType { + Absolute, + Alibi, +} + +// https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/configuration_persimmon.py +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: 
usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub hidden_act: candle_nn::Activation, + pub max_position_embeddings: usize, + pub initializer_range: f64, + pub layer_norm_eps: f64, + pub rms_norm_eps: f64, + pub use_cache: bool, + pub tie_word_embeddings: bool, + pub rope_theta: f64, + pub qk_layernorm: bool, + pub partial_rotary_factor: f64, +} + +impl Config { + pub fn base_8b() -> Self { + // https://huggingface.co/adept/persimmon-8b-base/blob/main/config.json + Self { + hidden_act: candle_nn::Activation::Relu, + hidden_size: 4096, + initializer_range: 0.02, + intermediate_size: 16384, + layer_norm_eps: 1e-05, + max_position_embeddings: 16384, + num_attention_heads: 64, + num_hidden_layers: 36, + num_key_value_heads: 64, + qk_layernorm: true, + rms_norm_eps: 1e-06, + rope_theta: 25000.0, + tie_word_embeddings: false, + use_cache: true, + vocab_size: 262144, + partial_rotary_factor: 0.5, + } + } +} From ef33df7ae2b94e2b911b61f3765d6826726614e7 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 28 Oct 2023 08:23:59 +0200 Subject: [PATCH 09/22] No need for the even constraint on vecdot-q40-q80. (#1202) --- candle-core/src/quantized/avx.rs | 5 ----- candle-core/src/quantized/k_quants.rs | 5 ----- candle-core/src/quantized/neon.rs | 29 ++------------------------- candle-core/src/quantized/simd128.rs | 4 ---- 4 files changed, 2 insertions(+), 41 deletions(-) diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs index d4b05bb0f4..5c3ac822e2 100644 --- a/candle-core/src/quantized/avx.rs +++ b/candle-core/src/quantized/avx.rs @@ -50,14 +50,9 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 { #[inline(always)] pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { let qk = QK8_0; - let nb = n / qk; if n % QK8_0 != 0 { crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") } - if nb % 2 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") - } - unsafe { let mut acc = _mm256_setzero_ps(); for (x, y) in xs.iter().zip(ys.iter()) { diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index b140131e4b..d16289e6ff 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -236,14 +236,9 @@ impl GgmlType for BlockQ4_0 { fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { let qk = QK8_0; - let nb = n / qk; if n % QK8_0 != 0 { crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") } - if nb % 2 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") - } - // Generic implementation. 
let mut sumf = 0f32; for (xs, ys) in xs.iter().zip(ys.iter()) { diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index 51bd3e66ef..3cb5622960 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -19,42 +19,29 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> if n % QK8_0 != 0 { crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") } - if nb % 2 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") - } unsafe { let mut sumv0 = vdupq_n_f32(0.0f32); - let mut sumv1 = vdupq_n_f32(0.0f32); - for i in (0..nb).step_by(2) { + for i in 0..nb { let x0 = &xs[i]; - let x1 = &xs[i + 1]; let y0 = &ys[i]; - let y1 = &ys[i + 1]; let m4b = vdupq_n_u8(0x0F); let s8b = vdupq_n_s8(0x8); let v0_0 = vld1q_u8(x0.qs.as_ptr()); - let v0_1 = vld1q_u8(x1.qs.as_ptr()); // 4-bit -> 8-bit let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b)); - let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); // sub 8 let v0_0ls = vsubq_s8(v0_0l, s8b); let v0_0hs = vsubq_s8(v0_0h, s8b); - let v0_1ls = vsubq_s8(v0_1l, s8b); - let v0_1hs = vsubq_s8(v0_1h, s8b); // load y let v1_0l = vld1q_s8(y0.qs.as_ptr()); let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); - let v1_1l = vld1q_s8(y1.qs.as_ptr()); - let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16)); // TODO: Support dotprod when it's available outside of nightly. let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l)); @@ -62,28 +49,16 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h)); let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l)); - let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); - let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h)); - let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); - let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0.d.to_f32() * y0.d.to_f32(), ); - sumv1 = vmlaq_n_f32( - sumv1, - vcvtq_f32_s32(vaddq_s32(pl1, ph1)), - x1.d.to_f32() * y1.d.to_f32(), - ); } - Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1)) + Ok(vaddvq_f32(sumv0)) } } diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs index f256fdc204..1c8c0f2068 100644 --- a/candle-core/src/quantized/simd128.rs +++ b/candle-core/src/quantized/simd128.rs @@ -11,10 +11,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> if n % QK8_0 != 0 { crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") } - let nb = n / QK8_0; - if nb % 2 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") - } unsafe { let mut acc = f32x4_splat(0.0f32); for (x, y) in xs.iter().zip(ys.iter()) { From 612f5b81561150ca6651368c245ac2065c04159a Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 28 Oct 2023 08:43:08 +0200 Subject: [PATCH 10/22] Make more models cloneable. 
(#1203) --- .../src/models/quantized_stable_lm.rs | 8 +++---- .../src/models/quantized_t5.rs | 22 +++++++++---------- candle-transformers/src/models/t5.rs | 22 +++++++++---------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs index d117e4b317..94c962014f 100644 --- a/candle-transformers/src/models/quantized_stable_lm.rs +++ b/candle-transformers/src/models/quantized_stable_lm.rs @@ -7,7 +7,7 @@ use std::sync::Arc; pub use crate::models::stable_lm::Config; use crate::models::stable_lm::RotaryEmbedding; -#[derive(Debug)] +#[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { gate_proj: Linear, @@ -43,7 +43,7 @@ impl Module for MLP { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Attention { q_proj: Linear, k_proj: Linear, @@ -168,7 +168,7 @@ impl Attention { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct DecoderLayer { self_attn: Attention, mlp: MLP, @@ -213,7 +213,7 @@ impl DecoderLayer { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Model { embed_tokens: Embedding, layers: Vec, diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs index 1426df39d1..4e5bd81a71 100644 --- a/candle-transformers/src/models/quantized_t5.rs +++ b/candle-transformers/src/models/quantized_t5.rs @@ -93,7 +93,7 @@ impl Default for Config { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5LayerNorm { weight: Tensor, variance_epsilon: f64, @@ -125,7 +125,7 @@ impl Module for T5LayerNorm { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5DenseActDense { wi: QMatMul, wo: QMatMul, @@ -156,7 +156,7 @@ impl Module for T5DenseActDense { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5DenseGatedActDense { wi_0: QMatMul, wi_1: QMatMul, @@ -191,7 +191,7 @@ impl Module for T5DenseGatedActDense { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5LayerFF { dense_act: Option, gated_dense_act: Option, @@ -236,7 +236,7 @@ impl Module for T5LayerFF { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5Attention { q: QMatMul, k: QMatMul, @@ -431,7 +431,7 @@ impl T5Attention { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5LayerSelfAttention { self_attention: T5Attention, layer_norm: T5LayerNorm, @@ -470,7 +470,7 @@ impl T5LayerSelfAttention { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5LayerCrossAttention { cross_attention: T5Attention, layer_norm: T5LayerNorm, @@ -512,7 +512,7 @@ impl T5LayerCrossAttention { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5Block { self_attn: T5LayerSelfAttention, cross_attn: Option, @@ -583,7 +583,7 @@ impl T5Block { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5Stack { block: Vec, shared: Arc, @@ -633,7 +633,7 @@ impl T5Stack { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct T5EncoderModel { encoder: T5Stack, device: Device, @@ -666,7 +666,7 @@ impl T5EncoderModel { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct T5ForConditionalGeneration { encoder: T5Stack, decoder: T5Stack, diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 9b3d97b8bd..1101d0013f 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -118,7 +118,7 @@ impl Config { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5LayerNorm { weight: Tensor, variance_epsilon: f64, @@ -150,7 +150,7 @@ impl Module for T5LayerNorm { } } 
-#[derive(Debug)] +#[derive(Debug, Clone)] struct T5DenseActDense { wi: Linear, wo: Linear, @@ -181,7 +181,7 @@ impl Module for T5DenseActDense { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5DenseGatedActDense { wi_0: Linear, wi_1: Linear, @@ -216,7 +216,7 @@ impl Module for T5DenseGatedActDense { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5LayerFF { dense_act: Option, gated_dense_act: Option, @@ -261,7 +261,7 @@ impl Module for T5LayerFF { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5Attention { q: Linear, k: Linear, @@ -456,7 +456,7 @@ impl T5Attention { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5LayerSelfAttention { self_attention: T5Attention, layer_norm: T5LayerNorm, @@ -495,7 +495,7 @@ impl T5LayerSelfAttention { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5LayerCrossAttention { cross_attention: T5Attention, layer_norm: T5LayerNorm, @@ -537,7 +537,7 @@ impl T5LayerCrossAttention { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5Block { self_attn: T5LayerSelfAttention, cross_attn: Option, @@ -608,7 +608,7 @@ impl T5Block { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct T5Stack { block: Vec, shared: Arc, @@ -658,7 +658,7 @@ impl T5Stack { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct T5EncoderModel { encoder: T5Stack, device: Device, @@ -691,7 +691,7 @@ impl T5EncoderModel { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct T5ForConditionalGeneration { encoder: T5Stack, decoder: T5Stack, From 95a857cf57c56a34ecdaae5372f2a13ebd900001 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 28 Oct 2023 17:51:19 +0200 Subject: [PATCH 11/22] Move the llama2-c model in transformers. (#1205) --- candle-examples/examples/llama2-c/main.rs | 6 +++--- candle-transformers/Cargo.toml | 1 + .../model.rs => candle-transformers/src/models/llama2_c.rs | 0 .../src/models/llama2_c_weights.rs | 5 ++--- candle-transformers/src/models/mod.rs | 3 +++ .../src/models/quantized_llama2_c.rs | 6 +++--- 6 files changed, 12 insertions(+), 9 deletions(-) rename candle-examples/examples/llama2-c/model.rs => candle-transformers/src/models/llama2_c.rs (100%) rename candle-examples/examples/llama2-c/weights.rs => candle-transformers/src/models/llama2_c_weights.rs (98%) rename candle-examples/examples/llama2-c/qmodel.rs => candle-transformers/src/models/quantized_llama2_c.rs (97%) diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 77dbc6778e..a3f01ae2e9 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -6,10 +6,10 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; -mod model; -mod qmodel; +use candle_transformers::models::llama2_c as model; +use candle_transformers::models::llama2_c_weights as weights; +use candle_transformers::models::quantized_llama2_c as qmodel; mod training; -mod weights; use clap::{Parser, Subcommand}; use anyhow::{Error as E, Result}; diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 5af7e55d7c..e7290be639 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } +byteorder = { workspace = true } candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true } candle-nn = { path = "../candle-nn", 
version = "0.3.0" } diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-transformers/src/models/llama2_c.rs similarity index 100% rename from candle-examples/examples/llama2-c/model.rs rename to candle-transformers/src/models/llama2_c.rs diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-transformers/src/models/llama2_c_weights.rs similarity index 98% rename from candle-examples/examples/llama2-c/weights.rs rename to candle-transformers/src/models/llama2_c_weights.rs index b78418ce35..e5a8bb8806 100644 --- a/candle-examples/examples/llama2-c/weights.rs +++ b/candle-transformers/src/models/llama2_c_weights.rs @@ -1,9 +1,8 @@ -use anyhow::Result; use byteorder::{LittleEndian, ReadBytesExt}; -use candle::{DType, Device, IndexOp, Shape, Tensor}; +use candle::{DType, Device, IndexOp, Result, Shape, Tensor}; use candle_nn::VarBuilder; -use crate::model::Config; +use super::llama2_c::Config; pub struct TransformerWeights { // token embedding table diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index f722e93b04..c59bd880cf 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -8,6 +8,8 @@ pub mod efficientnet; pub mod falcon; pub mod jina_bert; pub mod llama; +pub mod llama2_c; +pub mod llama2_c_weights; pub mod mistral; pub mod mixformer; pub mod mpt; @@ -15,6 +17,7 @@ pub mod persimmon; pub mod quantized_blip; pub mod quantized_blip_text; pub mod quantized_llama; +pub mod quantized_llama2_c; pub mod quantized_mistral; pub mod quantized_mixformer; pub mod quantized_mpt; diff --git a/candle-examples/examples/llama2-c/qmodel.rs b/candle-transformers/src/models/quantized_llama2_c.rs similarity index 97% rename from candle-examples/examples/llama2-c/qmodel.rs rename to candle-transformers/src/models/quantized_llama2_c.rs index 07db146ebd..68ebee0da4 100644 --- a/candle-examples/examples/llama2-c/qmodel.rs +++ b/candle-transformers/src/models/quantized_llama2_c.rs @@ -1,7 +1,7 @@ -use super::model::{Cache, Config}; +use super::llama2_c::{Cache, Config}; +use crate::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm}; +pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, IndexOp, Module, Result, Tensor, D}; -use candle_transformers::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm}; -pub use candle_transformers::quantized_var_builder::VarBuilder; fn silu(xs: &Tensor) -> Result { xs / (xs.neg()?.exp()? + 1.0)? From 012ae0090e70da67987a0308ef18587e9e8a8e44 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 28 Oct 2023 20:00:39 +0200 Subject: [PATCH 12/22] Infer the config for llama2-c. (#1208) --- candle-examples/examples/llama2-c/main.rs | 14 ++++++- candle-examples/examples/llama2-c/training.rs | 2 +- candle-transformers/src/models/llama2_c.rs | 41 ++++++++++++++++++- .../src/quantized_var_builder.rs | 10 +++++ 4 files changed, 63 insertions(+), 4 deletions(-) diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index a3f01ae2e9..0ceb27af7e 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -262,8 +262,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .extension() .map_or(false, |v| v == "safetensors"); let (model, config) = if is_gguf { - let config = Config::tiny(); let vb = qmodel::VarBuilder::from_gguf(config_path)?; + let (_vocab_size, dim) = vb + .get_no_shape("model.embed_tokens.weight")? 
+ .shape() + .dims2()?; + let config = match dim { + 64 => Config::tiny_260k(), + 288 => Config::tiny_15m(), + 512 => Config::tiny_42m(), + 768 => Config::tiny_110m(), + _ => anyhow::bail!("no config for dim {dim}"), + }; let freq_cis_real = vb .get( (config.seq_len, config.head_size() / 2), @@ -291,7 +301,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?); (model, config) } else if is_safetensors { - let config = Config::tiny(); + let config = Config::tiny_15m(); let tensors = candle::safetensors::load(config_path, &device)?; let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); let cache = model::Cache::new(true, &config, vb.pp("rot"))?; diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs index 150a327239..b2aa0889fc 100644 --- a/candle-examples/examples/llama2-c/training.rs +++ b/candle-examples/examples/llama2-c/training.rs @@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> { ); let varmap = candle_nn::VarMap::new(); let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device); - let config = Config::tiny(); + let config = Config::tiny_15m(); let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone()); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index 07a6e2f211..753770fb75 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -17,7 +17,20 @@ pub struct Config { } impl Config { - pub fn tiny() -> Self { + pub fn tiny_260k() -> Self { + Self { + dim: 64, + hidden_dim: 768, + n_layers: 5, + n_heads: 8, + n_kv_heads: 4, + vocab_size: 32000, + seq_len: 512, + norm_eps: 1e-5, + } + } + + pub fn tiny_15m() -> Self { Self { dim: 288, hidden_dim: 768, @@ -29,6 +42,32 @@ impl Config { norm_eps: 1e-5, } } + + pub fn tiny_42m() -> Self { + Self { + dim: 512, + hidden_dim: 768, + n_layers: 8, + n_heads: 8, + n_kv_heads: 8, + vocab_size: 32000, + seq_len: 1024, + norm_eps: 1e-5, + } + } + + pub fn tiny_110m() -> Self { + Self { + dim: 768, + hidden_dim: 768, + n_layers: 12, + n_heads: 12, + n_kv_heads: 12, + vocab_size: 32000, + seq_len: 1024, + norm_eps: 1e-5, + } + } } #[derive(Clone)] diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs index 259496d620..810802e8d7 100644 --- a/candle-transformers/src/quantized_var_builder.rs +++ b/candle-transformers/src/quantized_var_builder.rs @@ -77,6 +77,16 @@ impl VarBuilder { } } + pub fn get_no_shape(&self, name: &str) -> Result> { + let path = self.path(name); + match self.data.get(&path) { + None => { + candle::bail!("cannot find tensor {name}") + } + Some(qtensor) => Ok(qtensor.clone()), + } + } + pub fn device(&self) -> &Device { &self.device } From 498c50348ce13456d683c987ad9aef319a45eb4a Mon Sep 17 00:00:00 2001 From: Travis Hammond Date: Sat, 28 Oct 2023 20:53:34 +0200 Subject: [PATCH 13/22] Add DDPG and fix Gym wrapper (#1207) * Fix Gym wrapper - It was returning things in the wrong order - Gym now differentiates between terminated and truncated * Add DDPG * Apply fixes * Remove Result annotations * Also remove Vec annotation * rustfmt * Various small improvements (avoid cloning, mutability, get clippy to pass, ...) 
--------- Co-authored-by: Travis Hammond Co-authored-by: Laurent --- .../examples/reinforcement-learning/ddpg.rs | 451 ++++++++++++++++++ .../reinforcement-learning/gym_env.rs | 38 +- .../examples/reinforcement-learning/main.rs | 85 +++- 3 files changed, 549 insertions(+), 25 deletions(-) create mode 100644 candle-examples/examples/reinforcement-learning/ddpg.rs diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs new file mode 100644 index 0000000000..c6d72fed4c --- /dev/null +++ b/candle-examples/examples/reinforcement-learning/ddpg.rs @@ -0,0 +1,451 @@ +use std::collections::VecDeque; +use std::fmt::Display; + +use candle::{DType, Device, Error, Module, Result, Tensor, Var}; +use candle_nn::{ + func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential, + VarBuilder, VarMap, +}; +use rand::{distributions::Uniform, thread_rng, Rng}; + +pub struct OuNoise { + mu: f64, + theta: f64, + sigma: f64, + state: Tensor, +} +impl OuNoise { + pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result { + Ok(Self { + mu, + theta, + sigma, + state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?, + }) + } + + pub fn sample(&mut self) -> Result { + let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?; + let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?; + self.state = (&self.state + dx)?; + Ok(self.state.clone()) + } +} + +#[derive(Clone)] +struct Transition { + state: Tensor, + action: Tensor, + reward: Tensor, + next_state: Tensor, + terminated: bool, + truncated: bool, +} +impl Transition { + fn new( + state: &Tensor, + action: &Tensor, + reward: &Tensor, + next_state: &Tensor, + terminated: bool, + truncated: bool, + ) -> Self { + Self { + state: state.clone(), + action: action.clone(), + reward: reward.clone(), + next_state: next_state.clone(), + terminated, + truncated, + } + } +} + +pub struct ReplayBuffer { + buffer: VecDeque, + capacity: usize, + size: usize, +} +impl ReplayBuffer { + pub fn new(capacity: usize) -> Self { + Self { + buffer: VecDeque::with_capacity(capacity), + capacity, + size: 0, + } + } + + pub fn push( + &mut self, + state: &Tensor, + action: &Tensor, + reward: &Tensor, + next_state: &Tensor, + terminated: bool, + truncated: bool, + ) { + if self.size == self.capacity { + self.buffer.pop_front(); + } else { + self.size += 1; + } + self.buffer.push_back(Transition::new( + state, action, reward, next_state, terminated, truncated, + )); + } + + #[allow(clippy::type_complexity)] + pub fn random_batch( + &self, + batch_size: usize, + ) -> Result, Vec)>> { + if self.size < batch_size { + Ok(None) + } else { + let transitions: Vec<&Transition> = thread_rng() + .sample_iter(Uniform::from(0..self.size)) + .take(batch_size) + .map(|i| self.buffer.get(i).unwrap()) + .collect(); + + let states: Vec = transitions + .iter() + .map(|t| t.state.unsqueeze(0)) + .collect::>()?; + let actions: Vec = transitions + .iter() + .map(|t| t.action.unsqueeze(0)) + .collect::>()?; + let rewards: Vec = transitions + .iter() + .map(|t| t.reward.unsqueeze(0)) + .collect::>()?; + let next_states: Vec = transitions + .iter() + .map(|t| t.next_state.unsqueeze(0)) + .collect::>()?; + let terminateds: Vec = transitions.iter().map(|t| t.terminated).collect(); + let truncateds: Vec = transitions.iter().map(|t| t.truncated).collect(); + + Ok(Some(( + Tensor::cat(&states, 0)?, + Tensor::cat(&actions, 0)?, + Tensor::cat(&rewards, 0)?, + Tensor::cat(&next_states, 0)?, + 
terminateds, + truncateds, + ))) + } + } +} + +fn track( + varmap: &mut VarMap, + vb: &VarBuilder, + target_prefix: &str, + network_prefix: &str, + dims: &[(usize, usize)], + tau: f64, +) -> Result<()> { + for (i, &(in_dim, out_dim)) in dims.iter().enumerate() { + let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?; + let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?; + varmap.set_one( + format!("{target_prefix}-fc{i}.weight"), + ((tau * network_w)? + ((1.0 - tau) * target_w)?)?, + )?; + + let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?; + let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?; + varmap.set_one( + format!("{target_prefix}-fc{i}.bias"), + ((tau * network_b)? + ((1.0 - tau) * target_b)?)?, + )?; + } + Ok(()) +} + +struct Actor<'a> { + varmap: VarMap, + vb: VarBuilder<'a>, + network: Sequential, + target_network: Sequential, + size_state: usize, + size_action: usize, + dims: Vec<(usize, usize)>, +} + +impl Actor<'_> { + fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result { + let mut varmap = VarMap::new(); + let vb = VarBuilder::from_varmap(&varmap, dtype, device); + + let dims = vec![(size_state, 400), (400, 300), (300, size_action)]; + + let make_network = |prefix: &str| { + let seq = seq() + .add(linear( + dims[0].0, + dims[0].1, + vb.pp(format!("{prefix}-fc0")), + )?) + .add(Activation::Relu) + .add(linear( + dims[1].0, + dims[1].1, + vb.pp(format!("{prefix}-fc1")), + )?) + .add(Activation::Relu) + .add(linear( + dims[2].0, + dims[2].1, + vb.pp(format!("{prefix}-fc2")), + )?) + .add(func(|xs| xs.tanh())); + Ok::(seq) + }; + + let network = make_network("actor")?; + let target_network = make_network("target-actor")?; + + // this sets the two networks to be equal to each other using tau = 1.0 + track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0); + + Ok(Self { + varmap, + vb, + network, + target_network, + size_state, + size_action, + dims, + }) + } + + fn forward(&self, state: &Tensor) -> Result { + self.network.forward(state) + } + + fn target_forward(&self, state: &Tensor) -> Result { + self.target_network.forward(state) + } + + fn track(&mut self, tau: f64) -> Result<()> { + track( + &mut self.varmap, + &self.vb, + "target-actor", + "actor", + &self.dims, + tau, + ) + } +} + +struct Critic<'a> { + varmap: VarMap, + vb: VarBuilder<'a>, + network: Sequential, + target_network: Sequential, + size_state: usize, + size_action: usize, + dims: Vec<(usize, usize)>, +} + +impl Critic<'_> { + fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result { + let mut varmap = VarMap::new(); + let vb = VarBuilder::from_varmap(&varmap, dtype, device); + + let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)]; + + let make_network = |prefix: &str| { + let seq = seq() + .add(linear( + dims[0].0, + dims[0].1, + vb.pp(format!("{prefix}-fc0")), + )?) + .add(Activation::Relu) + .add(linear( + dims[1].0, + dims[1].1, + vb.pp(format!("{prefix}-fc1")), + )?) 
+ .add(Activation::Relu) + .add(linear( + dims[2].0, + dims[2].1, + vb.pp(format!("{prefix}-fc2")), + )?); + Ok::(seq) + }; + + let network = make_network("critic")?; + let target_network = make_network("target-critic")?; + + // this sets the two networks to be equal to each other using tau = 1.0 + track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0); + + Ok(Self { + varmap, + vb, + network, + target_network, + size_state, + size_action, + dims, + }) + } + + fn forward(&self, state: &Tensor, action: &Tensor) -> Result { + let xs = Tensor::cat(&[action, state], 1)?; + self.network.forward(&xs) + } + + fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result { + let xs = Tensor::cat(&[action, state], 1)?; + self.target_network.forward(&xs) + } + + fn track(&mut self, tau: f64) -> Result<()> { + track( + &mut self.varmap, + &self.vb, + "target-critic", + "critic", + &self.dims, + tau, + ) + } +} + +#[allow(clippy::upper_case_acronyms)] +pub struct DDPG<'a> { + actor: Actor<'a>, + actor_optim: AdamW, + critic: Critic<'a>, + critic_optim: AdamW, + gamma: f64, + tau: f64, + replay_buffer: ReplayBuffer, + ou_noise: OuNoise, + + size_state: usize, + size_action: usize, + pub train: bool, +} + +impl DDPG<'_> { + #[allow(clippy::too_many_arguments)] + pub fn new( + device: &Device, + size_state: usize, + size_action: usize, + train: bool, + actor_lr: f64, + critic_lr: f64, + gamma: f64, + tau: f64, + buffer_capacity: usize, + ou_noise: OuNoise, + ) -> Result { + let filter_by_prefix = |varmap: &VarMap, prefix: &str| { + varmap + .data() + .lock() + .unwrap() + .iter() + .filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone())) + .collect::>() + }; + + let actor = Actor::new(device, DType::F32, size_state, size_action)?; + let actor_optim = AdamW::new( + filter_by_prefix(&actor.varmap, "actor"), + ParamsAdamW { + lr: actor_lr, + ..Default::default() + }, + )?; + + let critic = Critic::new(device, DType::F32, size_state, size_action)?; + let critic_optim = AdamW::new( + filter_by_prefix(&critic.varmap, "critic"), + ParamsAdamW { + lr: critic_lr, + ..Default::default() + }, + )?; + + Ok(Self { + actor, + actor_optim, + critic, + critic_optim, + gamma, + tau, + replay_buffer: ReplayBuffer::new(buffer_capacity), + ou_noise, + size_state, + size_action, + train, + }) + } + + pub fn remember( + &mut self, + state: &Tensor, + action: &Tensor, + reward: &Tensor, + next_state: &Tensor, + terminated: bool, + truncated: bool, + ) { + self.replay_buffer + .push(state, action, reward, next_state, terminated, truncated) + } + + pub fn actions(&mut self, state: &Tensor) -> Result { + let actions = self + .actor + .forward(&state.detach()?.unsqueeze(0)?)? + .squeeze(0)?; + let actions = if self.train { + (actions + self.ou_noise.sample()?)? + } else { + actions + }; + actions.squeeze(0)?.to_scalar::() + } + + pub fn train(&mut self, batch_size: usize) -> Result<()> { + let (states, actions, rewards, next_states, _, _) = + match self.replay_buffer.random_batch(batch_size)? { + Some(v) => v, + _ => return Ok(()), + }; + + let q_target = self + .critic + .target_forward(&next_states, &self.actor.target_forward(&next_states)?)?; + let q_target = (rewards + (self.gamma * q_target)?.detach())?; + let q = self.critic.forward(&states, &actions)?; + let diff = (q_target - q)?; + + let critic_loss = diff.sqr()?.mean_all()?; + self.critic_optim.backward_step(&critic_loss)?; + + let actor_loss = self + .critic + .forward(&states, &self.actor.forward(&states)?)? + .mean_all()? 
+ .neg()?; + self.actor_optim.backward_step(&actor_loss)?; + + self.critic.track(self.tau)?; + self.actor.track(self.tau)?; + + Ok(()) + } +} diff --git a/candle-examples/examples/reinforcement-learning/gym_env.rs b/candle-examples/examples/reinforcement-learning/gym_env.rs index b98be6bc86..8868c1884d 100644 --- a/candle-examples/examples/reinforcement-learning/gym_env.rs +++ b/candle-examples/examples/reinforcement-learning/gym_env.rs @@ -7,20 +7,22 @@ use pyo3::types::PyDict; /// The return value for a step. #[derive(Debug)] pub struct Step { - pub obs: Tensor, + pub state: Tensor, pub action: A, pub reward: f64, - pub is_done: bool, + pub terminated: bool, + pub truncated: bool, } impl Step { /// Returns a copy of this step changing the observation tensor. - pub fn copy_with_obs(&self, obs: &Tensor) -> Step { + pub fn copy_with_obs(&self, state: &Tensor) -> Step { Step { - obs: obs.clone(), + state: state.clone(), action: self.action, reward: self.reward, - is_done: self.is_done, + terminated: self.terminated, + truncated: self.truncated, } } } @@ -63,14 +65,14 @@ impl GymEnv { /// Resets the environment, returning the observation tensor. pub fn reset(&self, seed: u64) -> Result { - let obs: Vec = Python::with_gil(|py| { + let state: Vec = Python::with_gil(|py| { let kwargs = PyDict::new(py); kwargs.set_item("seed", seed)?; - let obs = self.env.call_method(py, "reset", (), Some(kwargs))?; - obs.as_ref(py).get_item(0)?.extract() + let state = self.env.call_method(py, "reset", (), Some(kwargs))?; + state.as_ref(py).get_item(0)?.extract() }) .map_err(w)?; - Tensor::new(obs, &Device::Cpu) + Tensor::new(state, &Device::Cpu) } /// Applies an environment step using the specified action. @@ -78,21 +80,23 @@ impl GymEnv { &self, action: A, ) -> Result> { - let (obs, reward, is_done) = Python::with_gil(|py| { + let (state, reward, terminated, truncated) = Python::with_gil(|py| { let step = self.env.call_method(py, "step", (action.clone(),), None)?; let step = step.as_ref(py); - let obs: Vec = step.get_item(0)?.extract()?; + let state: Vec = step.get_item(0)?.extract()?; let reward: f64 = step.get_item(1)?.extract()?; - let is_done: bool = step.get_item(2)?.extract()?; - Ok((obs, reward, is_done)) + let terminated: bool = step.get_item(2)?.extract()?; + let truncated: bool = step.get_item(3)?.extract()?; + Ok((state, reward, terminated, truncated)) }) .map_err(w)?; - let obs = Tensor::new(obs, &Device::Cpu)?; + let state = Tensor::new(state, &Device::Cpu)?; Ok(Step { - obs, - reward, - is_done, + state, action, + reward, + terminated, + truncated, }) } diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs index f16f042e9e..96d7102d9f 100644 --- a/candle-examples/examples/reinforcement-learning/main.rs +++ b/candle-examples/examples/reinforcement-learning/main.rs @@ -9,14 +9,34 @@ extern crate accelerate_src; mod gym_env; mod vec_gym_env; -use candle::Result; +mod ddpg; + +use candle::{Device, Result, Tensor}; use clap::Parser; use rand::Rng; +// The impact of the q value of the next state on the current state's q value. +const GAMMA: f64 = 0.99; +// The weight for updating the target networks. +const TAU: f64 = 0.005; +// The capacity of the replay buffer used for sampling training data. +const REPLAY_BUFFER_CAPACITY: usize = 100_000; +// The training batch size for each training iteration. +const TRAINING_BATCH_SIZE: usize = 100; // The total number of episodes. 
const MAX_EPISODES: usize = 100; // The maximum length of an episode. const EPISODE_LENGTH: usize = 200; +// The number of training iterations after one episode finishes. +const TRAINING_ITERATIONS: usize = 200; + +// Ornstein-Uhlenbeck process parameters. +const MU: f64 = 0.0; +const THETA: f64 = 0.15; +const SIGMA: f64 = 0.1; + +const ACTOR_LEARNING_RATE: f64 = 1e-4; +const CRITIC_LEARNING_RATE: f64 = 1e-3; #[derive(Parser, Debug, Clone)] #[command(author, version, about, long_about = None)] @@ -48,28 +68,77 @@ fn main() -> Result<()> { println!("action space: {}", env.action_space()); println!("observation space: {:?}", env.observation_space()); - let _num_obs = env.observation_space().iter().product::(); - let _num_actions = env.action_space(); + let size_state = env.observation_space().iter().product::(); + let size_action = env.action_space(); + + let mut agent = ddpg::DDPG::new( + &Device::Cpu, + size_state, + size_action, + true, + ACTOR_LEARNING_RATE, + CRITIC_LEARNING_RATE, + GAMMA, + TAU, + REPLAY_BUFFER_CAPACITY, + ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?, + )?; let mut rng = rand::thread_rng(); for episode in 0..MAX_EPISODES { - let mut obs = env.reset(episode as u64)?; + // let mut state = env.reset(episode as u64)?; + let mut state = env.reset(rng.gen::())?; let mut total_reward = 0.0; for _ in 0..EPISODE_LENGTH { - let actions = rng.gen_range(-2.0..2.0); + let mut action = 2.0 * agent.actions(&state)?; + action = action.clamp(-2.0, 2.0); - let step = env.step(vec![actions])?; + let step = env.step(vec![action])?; total_reward += step.reward; - if step.is_done { + agent.remember( + &state, + &Tensor::new(vec![action], &Device::Cpu)?, + &Tensor::new(vec![step.reward as f32], &Device::Cpu)?, + &step.state, + step.terminated, + step.truncated, + ); + + if step.terminated || step.truncated { break; } - obs = step.obs; + state = step.state; } println!("episode {episode} with total reward of {total_reward}"); + + for _ in 0..TRAINING_ITERATIONS { + agent.train(TRAINING_BATCH_SIZE)?; + } + } + + println!("Testing..."); + agent.train = false; + for episode in 0..10 { + // let mut state = env.reset(episode as u64)?; + let mut state = env.reset(rng.gen::())?; + let mut total_reward = 0.0; + for _ in 0..EPISODE_LENGTH { + let mut action = 2.0 * agent.actions(&state)?; + action = action.clamp(-2.0, 2.0); + + let step = env.step(vec![action])?; + total_reward += step.reward; + + if step.terminated || step.truncated { + break; + } + state = step.state; + } + println!("episode {episode} with total reward of {total_reward}"); } Ok(()) } From dece37c6f4d9c5a52caf59a003afa6ba33034fe3 Mon Sep 17 00:00:00 2001 From: drbh Date: Sun, 29 Oct 2023 02:10:23 -0400 Subject: [PATCH 14/22] feat: implement VGG13, VGG16 and VGG19 (#1211) * feat: implement VGG13, VGG16 and VGG19 * Cosmetic fixes. * More cosmetic tweaks + avoid re-loading the weights on each final layer. 
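One implementation note that may help when reading the diff below: the timm checkpoints store the two pre-logits layers in a 4-D, convolution-style layout, so their weights have to be flattened into the `(out_features, in_features)` form that `candle_nn::Linear` expects before the classifier head can be assembled (this is what `get_weights_and_biases` in `vgg.rs` does). A minimal, self-contained sketch of that remapping, with zero-filled tensors standing in for real checkpoint data:

```rust
use candle::{DType, Device, Module, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // timm stores the first pre-logits layer as a 4096 x 512 x 7 x 7 tensor.
    let w4d = Tensor::zeros((4096, 512, 7, 7), DType::F32, &dev)?;
    // Flatten it into (out_features, in_features) = (4096, 25088).
    let w2d = w4d.reshape((4096, 512 * 7 * 7))?;
    let bias = Tensor::zeros(4096, DType::F32, &dev)?;
    let fc1 = candle_nn::Linear::new(w2d, Some(bias));
    // A flattened 7x7 feature map with 512 channels now passes straight through.
    let xs = Tensor::zeros((1, 512 * 7 * 7), DType::F32, &dev)?;
    let ys = fc1.forward(&xs)?;
    assert_eq!(ys.dims(), [1, 4096]);
    Ok(())
}
```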
--------- Co-authored-by: Laurent --- candle-examples/examples/vgg/README.md | 13 ++ candle-examples/examples/vgg/main.rs | 77 ++++++++ candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/vgg.rs | 254 +++++++++++++++++++++++++ 4 files changed, 345 insertions(+) create mode 100644 candle-examples/examples/vgg/README.md create mode 100644 candle-examples/examples/vgg/main.rs create mode 100644 candle-transformers/src/models/vgg.rs diff --git a/candle-examples/examples/vgg/README.md b/candle-examples/examples/vgg/README.md new file mode 100644 index 0000000000..473038e805 --- /dev/null +++ b/candle-examples/examples/vgg/README.md @@ -0,0 +1,13 @@ +## VGG Model Implementation + +This example demonstrates the implementation of VGG models (VGG13, VGG16, VGG19) using the Candle library. + +The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main function in `candle-examples/examples/vgg/main.rs` loads an image, selects the VGG model based on the provided argument, and applies the model to the loaded image. + +You can run the example with the following command: + +```bash +cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13 +``` + +In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19). diff --git a/candle-examples/examples/vgg/main.rs b/candle-examples/examples/vgg/main.rs new file mode 100644 index 0000000000..e01fa8e8b5 --- /dev/null +++ b/candle-examples/examples/vgg/main.rs @@ -0,0 +1,77 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::vgg::{Models, Vgg}; +use clap::{Parser, ValueEnum}; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + Vgg13, + Vgg16, + Vgg19, +} + +#[derive(Parser)] +struct Args { + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Variant of the model to use. + #[arg(value_enum, long, default_value_t = Which::Vgg13)] + which: Which, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + let device = candle_examples::device(args.cpu)?; + let image = candle_examples::imagenet::load_image224(args.image)?; + + println!("loaded image {image:?}"); + + let api = hf_hub::api::sync::Api::new()?; + let repo = match args.which { + Which::Vgg13 => "timm/vgg13.tv_in1k", + Which::Vgg16 => "timm/vgg16.tv_in1k", + Which::Vgg19 => "timm/vgg19.tv_in1k", + }; + let api = api.model(repo.into()); + let filename = "model.safetensors"; + let model_file = api.get(filename)?; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = match args.which { + Which::Vgg13 => Vgg::new(vb, Models::Vgg13)?, + Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?, + Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?, + }; + let logits = model.forward(&image)?; + + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? 
+ .to_vec1::()?; + + // Sort the predictions and take the top 5 + let mut top: Vec<_> = prs.iter().enumerate().collect(); + top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + let top = top.into_iter().take(5).collect::>(); + + // Print the top predictions + for &(i, p) in &top { + println!( + "{:50}: {:.2}%", + candle_examples::imagenet::CLASSES[i], + p * 100.0 + ); + } + + Ok(()) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index c59bd880cf..aecfcd672b 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -28,6 +28,7 @@ pub mod segment_anything; pub mod stable_diffusion; pub mod stable_lm; pub mod t5; +pub mod vgg; pub mod vit; pub mod whisper; pub mod with_tracing; diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs new file mode 100644 index 0000000000..7837dc3e69 --- /dev/null +++ b/candle-transformers/src/models/vgg.rs @@ -0,0 +1,254 @@ +//! VGG-16 model implementation. +//! +//! See Very Deep Convolutional Networks for Large-Scale Image Recognition +//! +use candle::{Module, Result, Tensor}; +use candle_nn::{Func, VarBuilder}; + +// Enum representing the different VGG models +pub enum Models { + Vgg13, + Vgg16, + Vgg19, +} + +// Struct representing a VGG model +#[derive(Debug)] +pub struct Vgg<'a> { + blocks: Vec>, +} + +// Struct representing the configuration for the pre-logit layer +struct PreLogitConfig { + in_dim: (usize, usize, usize, usize), + target_in: usize, + target_out: usize, +} + +// Implementation of the VGG model +impl<'a> Vgg<'a> { + // Function to create a new VGG model + pub fn new(vb: VarBuilder<'a>, model: Models) -> Result { + let blocks = match model { + Models::Vgg13 => vgg13_blocks(vb)?, + Models::Vgg16 => vgg16_blocks(vb)?, + Models::Vgg19 => vgg19_blocks(vb)?, + }; + Ok(Self { blocks }) + } +} + +// Implementation of the forward pass for the VGG model +impl Module for Vgg<'_> { + fn forward(&self, xs: &Tensor) -> Result { + let mut xs = xs.unsqueeze(0)?; + for block in self.blocks.iter() { + xs = xs.apply(block)?; + } + Ok(xs) + } +} + +// Function to create a conv2d block +// The block is composed of two conv2d layers followed by a max pool layer +fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result> { + let layers = convs + .iter() + .enumerate() + .map(|(_, &(in_c, out_c, name))| { + candle_nn::conv2d( + in_c, + out_c, + 3, + candle_nn::Conv2dConfig { + stride: 1, + padding: 1, + ..Default::default() + }, + vb.pp(name), + ) + }) + .collect::>>()?; + + Ok(Func::new(move |xs| { + let mut xs = xs.clone(); + for layer in layers.iter() { + xs = xs.apply(layer)?.relu()? 
+ } + xs = xs.max_pool2d_with_stride(2, 2)?; + Ok(xs) + })) +} + +// Function to create a fully connected layer +// The layer is composed of two linear layers followed by a dropout layer +fn fully_connected( + num_classes: usize, + pre_logit_1: PreLogitConfig, + pre_logit_2: PreLogitConfig, + vb: VarBuilder, +) -> Result { + let lin = get_weights_and_biases( + &vb.pp("pre_logits.fc1"), + pre_logit_1.in_dim, + pre_logit_1.target_in, + pre_logit_1.target_out, + )?; + let lin2 = get_weights_and_biases( + &vb.pp("pre_logits.fc2"), + pre_logit_2.in_dim, + pre_logit_2.target_in, + pre_logit_2.target_out, + )?; + Ok(Func::new(move |xs| { + let xs = xs.reshape((1, pre_logit_1.target_out))?; + let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin)?.relu()?; + let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin2)?.relu()?; + let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?; + let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin3)?.relu()?; + Ok(xs) + })) +} + +// Function to get the weights and biases for a layer +// This is required because the weights and biases are stored in different format than our linear layer expects +fn get_weights_and_biases( + vs: &VarBuilder, + in_dim: (usize, usize, usize, usize), + target_in: usize, + target_out: usize, +) -> Result { + let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL; + let ws = vs.get_with_hints(in_dim, "weight", init_ws)?; + let ws = ws.reshape((target_in, target_out))?; + let bound = 1. / (target_out as f64).sqrt(); + let init_bs = candle_nn::Init::Uniform { + lo: -bound, + up: bound, + }; + let bs = vs.get_with_hints(target_in, "bias", init_bs)?; + Ok(candle_nn::Linear::new(ws, Some(bs))) +} + +fn vgg13_blocks(vb: VarBuilder) -> Result> { + let num_classes = 1000; + let blocks = vec![ + conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, + conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?, + conv2d_block(&[(128, 256, "features.10"), (256, 256, "features.12")], &vb)?, + conv2d_block(&[(256, 512, "features.15"), (512, 512, "features.17")], &vb)?, + conv2d_block(&[(512, 512, "features.20"), (512, 512, "features.22")], &vb)?, + fully_connected( + num_classes, + PreLogitConfig { + in_dim: (4096, 512, 7, 7), + target_in: 4096, + target_out: 512 * 7 * 7, + }, + PreLogitConfig { + in_dim: (4096, 4096, 1, 1), + target_in: 4096, + target_out: 4096, + }, + vb.clone(), + )?, + ]; + Ok(blocks) +} + +fn vgg16_blocks(vb: VarBuilder) -> Result> { + let num_classes = 1000; + let blocks = vec![ + conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, + conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?, + conv2d_block( + &[ + (128, 256, "features.10"), + (256, 256, "features.12"), + (256, 256, "features.14"), + ], + &vb, + )?, + conv2d_block( + &[ + (256, 512, "features.17"), + (512, 512, "features.19"), + (512, 512, "features.21"), + ], + &vb, + )?, + conv2d_block( + &[ + (512, 512, "features.24"), + (512, 512, "features.26"), + (512, 512, "features.28"), + ], + &vb, + )?, + fully_connected( + num_classes, + PreLogitConfig { + in_dim: (4096, 512, 7, 7), + target_in: 4096, + target_out: 512 * 7 * 7, + }, + PreLogitConfig { + in_dim: (4096, 4096, 1, 1), + target_in: 4096, + target_out: 4096, + }, + vb.clone(), + )?, + ]; + Ok(blocks) +} + +fn vgg19_blocks(vb: VarBuilder) -> Result> { + let num_classes = 1000; + let blocks = vec![ + conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, + conv2d_block(&[(64, 128, "features.5"), (128, 128, 
"features.7")], &vb)?, + conv2d_block( + &[ + (128, 256, "features.10"), + (256, 256, "features.12"), + (256, 256, "features.14"), + (256, 256, "features.16"), + ], + &vb, + )?, + conv2d_block( + &[ + (256, 512, "features.19"), + (512, 512, "features.21"), + (512, 512, "features.23"), + (512, 512, "features.25"), + ], + &vb, + )?, + conv2d_block( + &[ + (512, 512, "features.28"), + (512, 512, "features.30"), + (512, 512, "features.32"), + (512, 512, "features.34"), + ], + &vb, + )?, + fully_connected( + num_classes, + PreLogitConfig { + in_dim: (4096, 512, 7, 7), + target_in: 4096, + target_out: 512 * 7 * 7, + }, + PreLogitConfig { + in_dim: (4096, 4096, 1, 1), + target_in: 4096, + target_out: 4096, + }, + vb.clone(), + )?, + ]; + Ok(blocks) +} From 55bc3382cfd3a86018c54f2343567f7c0c0b677c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 29 Oct 2023 07:53:09 +0100 Subject: [PATCH 15/22] Allow for different behavior between training and eval (#1213) * Forward with training. * Do not use dropout on vgg evaluation. --- candle-core/src/lib.rs | 12 +++++++ candle-core/src/tensor.rs | 5 +++ .../examples/mnist-training/main.rs | 4 +-- candle-examples/examples/vgg/main.rs | 4 +-- candle-nn/src/func.rs | 35 +++++++++++++++++++ candle-nn/src/lib.rs | 4 +-- candle-nn/src/ops.rs | 6 ++++ candle-transformers/src/models/vgg.rs | 35 ++++++++++--------- 8 files changed, 83 insertions(+), 22 deletions(-) diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 52effdcf80..73830229cf 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -125,3 +125,15 @@ impl Result> Module for T { self(xs) } } + +// A trait defining a module with forward method using a single tensor argument and a flag to +// separate the training and evaluation behaviors. +pub trait ModuleT { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result; +} + +impl ModuleT for M { + fn forward_t(&self, xs: &Tensor, _train: bool) -> Result { + self.forward(xs) + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ce81d8aff0..c6f2364d60 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2271,6 +2271,11 @@ impl Tensor { m.forward(self) } + /// Run the `forward` method of `m` on `self`. + pub fn apply_t(&self, m: &M, train: bool) -> Result { + m.forward_t(self, train) + } + pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { self.storage.read().unwrap() } diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs index a07505bf46..a41a6496b9 100644 --- a/candle-examples/examples/mnist-training/main.rs +++ b/candle-examples/examples/mnist-training/main.rs @@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum}; use rand::prelude::*; use candle::{DType, Result, Tensor, D}; -use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap}; +use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap}; const IMAGE_DIM: usize = 784; const LABELS: usize = 10; @@ -95,7 +95,7 @@ impl ConvNet { .flatten_from(1)? .apply(&self.fc1)? 
.relu()?; - self.dropout.forward(&xs, train)?.apply(&self.fc2) + self.dropout.forward_t(&xs, train)?.apply(&self.fc2) } } diff --git a/candle-examples/examples/vgg/main.rs b/candle-examples/examples/vgg/main.rs index e01fa8e8b5..27e141cb95 100644 --- a/candle-examples/examples/vgg/main.rs +++ b/candle-examples/examples/vgg/main.rs @@ -5,7 +5,7 @@ extern crate intel_mkl_src; extern crate accelerate_src; use candle::{DType, IndexOp, D}; -use candle_nn::{Module, VarBuilder}; +use candle_nn::{ModuleT, VarBuilder}; use candle_transformers::models::vgg::{Models, Vgg}; use clap::{Parser, ValueEnum}; @@ -53,7 +53,7 @@ pub fn main() -> anyhow::Result<()> { Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?, Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?, }; - let logits = model.forward(&image)?; + let logits = model.forward_t(&image, /*train=*/ false)?; let prs = candle_nn::ops::softmax(&logits, D::Minus1)? .i(0)? diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs index 39311d458c..3adfda860d 100644 --- a/candle-nn/src/func.rs +++ b/candle-nn/src/func.rs @@ -36,3 +36,38 @@ impl<'a> Func<'a> { Self { f: Arc::new(f) } } } + +/// A layer defined by a simple closure. +#[derive(Clone)] +pub struct FuncT<'a> { + #[allow(clippy::type_complexity)] + f: Arc Result + Send + Sync>, +} + +impl<'a> std::fmt::Debug for FuncT<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "func") + } +} + +pub fn func_t<'a, F>(f: F) -> FuncT<'a> +where + F: 'a + Fn(&Tensor, bool) -> Result + Send + Sync, +{ + FuncT { f: Arc::new(f) } +} + +impl<'a> super::ModuleT for FuncT<'a> { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result { + (*self.f)(xs, train) + } +} + +impl<'a> FuncT<'a> { + pub fn new(f: F) -> Self + where + F: 'a + Fn(&Tensor, bool) -> Result + Send + Sync, + { + Self { f: Arc::new(f) } + } +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index be95f53121..52d8f0c595 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -22,7 +22,7 @@ pub use conv::{ Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig, }; pub use embedding::{embedding, Embedding}; -pub use func::{func, Func}; +pub use func::{func, func_t, Func, FuncT}; pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; @@ -34,4 +34,4 @@ pub use sequential::{seq, Sequential}; pub use var_builder::VarBuilder; pub use var_map::VarMap; -pub use candle::Module; +pub use candle::{Module, ModuleT}; diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 32de1af9c8..e98121083e 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -84,6 +84,12 @@ impl Dropout { } } +impl candle::ModuleT for Dropout { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result { + self.forward(xs, train) + } +} + struct SoftmaxLastDim; impl candle::CustomOp1 for SoftmaxLastDim { diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs index 7837dc3e69..a20b5e3725 100644 --- a/candle-transformers/src/models/vgg.rs +++ b/candle-transformers/src/models/vgg.rs @@ -2,8 +2,8 @@ //! //! See Very Deep Convolutional Networks for Large-Scale Image Recognition //! 
-use candle::{Module, Result, Tensor}; -use candle_nn::{Func, VarBuilder}; +use candle::{ModuleT, Result, Tensor}; +use candle_nn::{FuncT, VarBuilder}; // Enum representing the different VGG models pub enum Models { @@ -15,7 +15,7 @@ pub enum Models { // Struct representing a VGG model #[derive(Debug)] pub struct Vgg<'a> { - blocks: Vec>, + blocks: Vec>, } // Struct representing the configuration for the pre-logit layer @@ -39,11 +39,11 @@ impl<'a> Vgg<'a> { } // Implementation of the forward pass for the VGG model -impl Module for Vgg<'_> { - fn forward(&self, xs: &Tensor) -> Result { +impl ModuleT for Vgg<'_> { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result { let mut xs = xs.unsqueeze(0)?; for block in self.blocks.iter() { - xs = xs.apply(block)?; + xs = xs.apply_t(block, train)?; } Ok(xs) } @@ -51,7 +51,7 @@ impl Module for Vgg<'_> { // Function to create a conv2d block // The block is composed of two conv2d layers followed by a max pool layer -fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result> { +fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result> { let layers = convs .iter() .enumerate() @@ -70,7 +70,7 @@ fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result>>()?; - Ok(Func::new(move |xs| { + Ok(FuncT::new(move |xs, _train| { let mut xs = xs.clone(); for layer in layers.iter() { xs = xs.apply(layer)?.relu()? @@ -87,7 +87,7 @@ fn fully_connected( pre_logit_1: PreLogitConfig, pre_logit_2: PreLogitConfig, vb: VarBuilder, -) -> Result { +) -> Result { let lin = get_weights_and_biases( &vb.pp("pre_logits.fc1"), pre_logit_1.in_dim, @@ -100,12 +100,15 @@ fn fully_connected( pre_logit_2.target_in, pre_logit_2.target_out, )?; - Ok(Func::new(move |xs| { + let dropout1 = candle_nn::Dropout::new(0.5); + let dropout2 = candle_nn::Dropout::new(0.5); + let dropout3 = candle_nn::Dropout::new(0.5); + Ok(FuncT::new(move |xs, train| { let xs = xs.reshape((1, pre_logit_1.target_out))?; - let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin)?.relu()?; - let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin2)?.relu()?; + let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?; + let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?; let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?; - let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin3)?.relu()?; + let xs = xs.apply_t(&dropout3, train)?.apply(&lin3)?.relu()?; Ok(xs) })) } @@ -130,7 +133,7 @@ fn get_weights_and_biases( Ok(candle_nn::Linear::new(ws, Some(bs))) } -fn vgg13_blocks(vb: VarBuilder) -> Result> { +fn vgg13_blocks(vb: VarBuilder) -> Result> { let num_classes = 1000; let blocks = vec![ conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, @@ -156,7 +159,7 @@ fn vgg13_blocks(vb: VarBuilder) -> Result> { Ok(blocks) } -fn vgg16_blocks(vb: VarBuilder) -> Result> { +fn vgg16_blocks(vb: VarBuilder) -> Result> { let num_classes = 1000; let blocks = vec![ conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, @@ -203,7 +206,7 @@ fn vgg16_blocks(vb: VarBuilder) -> Result> { Ok(blocks) } -fn vgg19_blocks(vb: VarBuilder) -> Result> { +fn vgg19_blocks(vb: VarBuilder) -> Result> { let num_classes = 1000; let blocks = vec![ conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, From 46d6566c99f63fc74f3fbf5754183a49219224d5 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 29 Oct 2023 10:50:04 +0100 Subject: [PATCH 16/22] Fix the conv2d gradient computation. 
(#1214) --- candle-core/src/backprop.rs | 7 ++++ candle-core/tests/conv_tests.rs | 65 +++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 7488d93979..155f49c5ca 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -238,6 +238,13 @@ impl Tensor { .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)? .transpose(0, 1)?; let sum_grad = grads.or_insert(kernel)?; + let (_, _, k0, k1) = kernel.dims4()?; + let (_, _, g_k0, g_k1) = grad_kernel.dims4()?; + let grad_kernel = if g_k0 != k0 || g_k1 != k1 { + grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)? + } else { + grad_kernel + }; *sum_grad = sum_grad.add(&grad_kernel)?; } Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported { diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 937ddf6761..e7fdf1381a 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -479,6 +479,71 @@ fn conv2d_grad(dev: &Device) -> Result<()> { ] ] ); + + // Replicate the issue from https://github.com/huggingface/candle/issues/1212 + let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?; + let loss = res.sqr()?.sum_all()?; + assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32); + let grads = loss.backward()?; + let grad_t = grads.get(&t).unwrap(); + let grad_w = grads.get(&w).unwrap(); + assert_eq!(grad_t.dims(), [1, 4, 5, 5]); + assert_eq!(grad_w.dims(), [2, 4, 3, 3]); + assert_eq!( + test_utils::to_vec3_round(&grad_t.i(0)?, 2)?, + [ + [ + [9.29, -7.03, 7.87, 0.0, 0.0], + [-1.8, -7.82, 5.9, 0.0, 0.0], + [-3.12, 4.49, 5.52, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [21.73, 3.39, 4.77, 0.0, 0.0], + [8.25, 3.73, 27.61, 0.0, 0.0], + [-20.55, -5.61, -2.77, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [-8.98, 9.91, -7.15, 0.0, 0.0], + [4.93, -0.33, 4.56, 0.0, 0.0], + [-6.7, -5.76, -8.05, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [23.54, 6.98, -10.0, 0.0, 0.0], + [9.65, 6.18, 18.72, 0.0, 0.0], + [3.29, -5.27, 0.79, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ] + ] + ); + assert_eq!( + test_utils::to_vec3_round(&grad_w.i(0)?, 2)?, + [ + [ + [-3.47, 7.44, 0.66], + [12.89, -3.4, -9.29], + [-14.16, -0.83, 7.14] + ], + [ + [-3.23, 5.37, -3.02], + [-2.12, -11.24, 1.94], + [6.97, 7.2, 2.99] + ], + [ + [-4.04, -3.31, 4.87], + [-6.68, -5.68, 1.73], + [-5.54, 4.32, 0.52] + ], + [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]] + ] + ); + Ok(()) } From c3f2676d4932daaa5aa7e1bb7faad343ad54d36e Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Sun, 29 Oct 2023 14:44:05 +0100 Subject: [PATCH 17/22] PyO3: Add CI to build & upload wheels as artifacts. 
(#1215) * Add maturin ci * fix paths * Change sdist path --- .github/workflows/maturin.yml | Bin 0 -> 5304 bytes candle-pyo3/Cargo.toml | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 .github/workflows/maturin.yml diff --git a/.github/workflows/maturin.yml b/.github/workflows/maturin.yml new file mode 100644 index 0000000000000000000000000000000000000000..1413f01475fb1a0baea85c10e3c05c82e19b415d GIT binary patch literal 5304 zcmeH~-EPxB5QXO&iFe2aQlyGBRIOUz3h@GPgSbElrHPvc@`vpZ5aQK=^X+z$wcRFF zMWYHuR$^y8J9FmmWcJVRsr77PnZ2}@y|o|q#*VFH@9k1+nT@Tm$Mz_EW;T@+zgoKH zw$QuFWQEmp%cXB>{jk5Ny+xv<&qOjKNx3f8ORWv1aczNB-_f=MYggpwk}qZDrXBr& zV;~PQ*__L>nLO)C&%sI$K8$sJ66(yp>Q^RxWevl>u(Xu*+`ia_tj%mGivvZV5H7qR zTG4MJG8c+mG2(rpZ{nVM*$*qFq^8=-n^wTzlTOdoXUUvbc8x>C7xum8T`sTD-w9gL zwa`x1N_}^P7lbh`X}*XPx#rza(QW39EF&{&*!0Yj^IW?#$zl0}B%j-mm&`)&dSCvT}lV!A>G_;{{DGW$8qwL73aM#n39%{&| z2T$6U7Z^{F&g>9tF4}#??6SzIUr+zwrbVG;ZA=#hG%QTL7d2~lF>lO3F ze(H*h;2&rlKh^}p^ zDyVDp*V(e}1Kp9<$vb%rpTRuT=dt%pMP=>^wBwd6uLQjDYA^<746w)aJfcN;=9 zI@@oKOm|fc#K_|W8lnip*V(1XvsLNcGp+Mz4~>8KOj>3B`({%9@0fG9v*=oK_Y7K{RjU)~o=-RP>7MhLFdp*Ex#zsya$d2T zDXXWv=Hz($6Z-w1bJO`com%{dJaF65V literal 0 HcmV?d00001 diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 0241d2b29a..488404bf06 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -19,7 +19,7 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle-nn = { path = "../candle-nn", version = "0.3.0" } half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } -pyo3 = { version = "0.19.0", features = ["extension-module"] } +pyo3 = { version = "0.19.0", features = ["extension-module", "abi3-py38"] } [build-dependencies] pyo3-build-config = "0.19" From 7bbde55c61d9bff90c9f7d0005ed17bbea4b4a8f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 29 Oct 2023 16:12:22 +0100 Subject: [PATCH 18/22] Marian MT model (#1210) * Skeleton files for the marian MT model. * Marian initialization. * Implement the attention forward method. * Forward pass for the encoder side. * Expose the encoder and decoder. * Start plugging the decoder. * Forward pass for the decoder layer. * Set up the marian example. * Add some missing backtraces. * Bugfix. 
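A couple of notes that may help when reading `marian.rs` below (this just summarises the code, it does not change it): the sinusoidal positional table is laid out Marian-style, with all sine features first and all cosine features second rather than interleaved, i.e. roughly

    PE[pos, j]       = sin(pos * w_j)
    PE[pos, d/2 + j] = cos(pos * w_j)    with w_j = 10000^(-2j/d) for 0 <= j < d/2,

and attention is the usual softmax(q k^T / sqrt(d_head)) v, with the 1/sqrt(d_head) factor applied to the projected queries. Attention masking is still marked as a TODO in this patch.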
--- candle-core/src/cpu_backend.rs | 20 +- candle-examples/examples/marian-mt/main.rs | 90 ++++ candle-transformers/src/models/marian.rs | 413 ++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + .../src/models/with_tracing.rs | 7 + 5 files changed, 521 insertions(+), 10 deletions(-) create mode 100644 candle-examples/examples/marian-mt/main.rs create mode 100644 candle-transformers/src/models/marian.rs diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 86cbeb78ab..e9ff464124 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> { fn f(&self, src: &[T], src_l: &Layout) -> Result> { let ids = match self.ids_l.contiguous_offsets() { Some((a, b)) => &self.ids[a..b], - None => Err(Error::RequiresContiguous { op: "gather" })?, + None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, }; let src = match src_l.contiguous_offsets() { Some((a, b)) => &src[a..b], - None => Err(Error::RequiresContiguous { op: "gather" })?, + None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, }; let dim = self.dim; let ids_dims = self.ids_l.dims(); @@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> { fn f(&self, src: &[T], layout: &Layout) -> Result> { let src = match layout.contiguous_offsets() { Some((a, b)) => &src[a..b], - None => Err(Error::RequiresContiguous { op: "index-select" })?, + None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?, }; let dim = self.dim; let n_ids = match self.ids_l.dims() { @@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> { let mut dst = vec![T::zero(); dst_len]; copy_strided_src_(v1, &mut dst, 0, l1); let src = match src_l.contiguous_offsets() { - None => Err(Error::RequiresContiguous { op: "scatter-add" })?, + None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?, Some((o1, o2)) => &src[o1..o2], }; @@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> { let ids = match self.ids_l.contiguous_offsets() { Some((a, b)) => &self.ids[a..b], - None => Err(Error::RequiresContiguous { op: "gather" })?, + None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, }; for left_i in 0..ids_left_len { let start_ids_idx = left_i * ids_right_len * ids_dim_len; @@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> { let mut dst = vec![T::zero(); dst_len]; copy_strided_src_(v1, &mut dst, 0, l1); let src = match src_l.contiguous_offsets() { - None => Err(Error::RequiresContiguous { op: "index-add" })?, + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, Some((o1, o2)) => &src[o1..o2], }; let dim = self.dim; @@ -2539,25 +2539,25 @@ impl BackendStorage for CpuStorage { Self::U8(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], - None => Err(Error::RequiresContiguous { op: "index-add" })?, + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, }; IndexAdd { ids, dim }.map(self, l, src, src_l) } Self::U32(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], - None => Err(Error::RequiresContiguous { op: "index-add" })?, + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, }; IndexAdd { ids, dim }.map(self, l, src, src_l) } Self::I64(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], - None => Err(Error::RequiresContiguous { op: "index-add" })?, + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, }; IndexAdd 
{ ids, dim }.map(self, l, src, src_l) } - _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")), + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()), } } diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs new file mode 100644 index 0000000000..ed044627c6 --- /dev/null +++ b/candle-examples/examples/marian-mt/main.rs @@ -0,0 +1,90 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Error as E; +use clap::Parser; + +use candle::{DType, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::models::marian; + +use tokenizers::Tokenizer; + +// TODO: Maybe add support for the conditional prompt. +#[derive(Parser)] +struct Args { + #[arg(long)] + model: String, + + #[arg(long)] + tokenizer: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Use the quantized version of the model. + #[arg(long)] + quantized: bool, + + /// Text to be translated + #[arg(long)] + text: String, +} + +const SEP_TOKEN_ID: u32 = 102; + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let config = marian::Config::opus_mt_tc_big_fr_en(); + + let device = candle_examples::device(args.cpu)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? }; + let model = marian::MTModel::new(&config, vb)?; + + let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?; + let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone()); + let mut logits_processor = + candle_transformers::generation::LogitsProcessor::new(1337, None, None); + + let encoder_xs = { + let tokens = tokenizer + .encode(args.text, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; + model.encoder().forward(&tokens, 0)? + }; + + let mut token_ids = vec![30522u32]; + for index in 0..1000 { + // TODO: Add a kv cache. + let context_size = if index >= 1000 { 1 } else { token_ids.len() }; + let start_pos = token_ids.len().saturating_sub(context_size); + let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?; + let logits = model.decode(&input_ids, &encoder_xs)?; + let logits = logits.squeeze(0)?; + let logits = logits.get(logits.dim(0)? - 1)?; + let token = logits_processor.sample(&logits)?; + if token == SEP_TOKEN_ID { + break; + } + token_ids.push(token); + if let Some(t) = tokenizer_dec.next_token(token)? { + use std::io::Write; + print!("{t}"); + std::io::stdout().flush()?; + } + } + if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? 
{ + print!("{rest}"); + } + + Ok(()) +} diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs new file mode 100644 index 0000000000..d48ce38b11 --- /dev/null +++ b/candle-transformers/src/models/marian.rs @@ -0,0 +1,413 @@ +#![allow(unused)] +use super::with_tracing::{linear, linear_no_bias, Embedding, Linear}; +use candle::{Module, Result, Tensor}; +use candle_nn::{layer_norm, LayerNorm, VarBuilder}; + +#[derive(Debug, Clone)] +pub struct Config { + pub vocab_size: usize, + pub decoder_vocab_size: Option, + pub max_position_embeddings: usize, + pub encoder_layers: usize, + pub encoder_ffn_dim: usize, + pub encoder_attention_heads: usize, + pub decoder_layers: usize, + pub decoder_ffn_dim: usize, + pub decoder_attention_heads: usize, + pub use_cache: bool, + pub is_encoder_decoder: bool, + pub activation_function: candle_nn::Activation, + pub d_model: usize, + pub decoder_start_token_id: usize, + pub scale_embedding: bool, + pub pad_token_id: usize, + pub eos_token_id: usize, + pub forced_eos_token_id: usize, + pub share_encoder_decoder_embeddings: bool, +} + +impl Config { + // https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en/blob/main/config.json + pub fn opus_mt_tc_big_fr_en() -> Self { + Self { + activation_function: candle_nn::Activation::Relu, + d_model: 1024, + decoder_attention_heads: 16, + decoder_ffn_dim: 4096, + decoder_layers: 6, + decoder_start_token_id: 53016, + decoder_vocab_size: Some(53017), + encoder_attention_heads: 16, + encoder_ffn_dim: 4096, + encoder_layers: 6, + eos_token_id: 43311, + forced_eos_token_id: 43311, + is_encoder_decoder: true, + max_position_embeddings: 1024, + pad_token_id: 53016, + scale_embedding: true, + share_encoder_decoder_embeddings: true, + use_cache: true, + vocab_size: 53017, + } + } +} + +#[derive(Debug, Clone)] +struct SinusoidalPositionalEmbedding { + emb: Embedding, +} + +impl SinusoidalPositionalEmbedding { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let dev = vb.device(); + let dtype = vb.dtype(); + let num_positions = cfg.max_position_embeddings; + let dim = cfg.d_model; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, num_positions as u32, dev)? + .to_dtype(dtype)? + .reshape((num_positions, 1))?; + let freqs = t.matmul(&inv_freq)?; + let sin = freqs.sin()?; + let cos = freqs.cos()?; + let weights = Tensor::cat(&[&sin, &cos], 1)?.contiguous()?; + let emb = Embedding::from_weights(weights)?; + Ok(Self { emb }) + } + + fn forward(&self, input_ids: &Tensor, past_kv_len: usize) -> Result { + let seq_len = input_ids.dim(1)?; + Tensor::arange( + past_kv_len as u32, + (past_kv_len + seq_len) as u32, + input_ids.device(), + )? 
+ .apply(&self.emb) + } +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + out_proj: Linear, + scaling: f64, + num_heads: usize, + head_dim: usize, +} + +impl Attention { + fn new(cfg: &Config, is_decoder: bool, vb: VarBuilder) -> Result { + let num_heads = if is_decoder { + cfg.decoder_attention_heads + } else { + cfg.encoder_attention_heads + }; + let embed_dim = cfg.d_model; + let head_dim = embed_dim / num_heads; + let scaling = (head_dim as f64).powf(-0.5); + let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?; + let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?; + let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?; + let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + out_proj, + scaling, + num_heads, + head_dim, + }) + } + + fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result { + tensor + .reshape((bsz, (), self.num_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous() + } + + fn forward(&self, xs: &Tensor, kv_states: Option<&Tensor>) -> Result { + let is_cross_attn = kv_states.is_some(); + let (b_sz, tgt_len, _) = xs.dims3()?; + let query_states = (xs.apply(&self.q_proj)? * self.scaling)?; + let (key_states, value_states) = match kv_states { + None => { + let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?; + let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?; + (key_states, value_states) + } + Some(kv_states) => { + let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?; + let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?; + (key_states, value_states) + } + }; + let proj_shape = (b_sz * self.num_heads, (), self.head_dim); + let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?; + let key_states = key_states.reshape(proj_shape)?; + let value_states = value_states.reshape(proj_shape)?; + let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; + // todo: attn_mask + let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_output = attn_probs.matmul(&value_states)?; + attn_output + .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))? + .transpose(1, 2)? + .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))? + .apply(&self.out_proj) + } +} + +#[derive(Debug, Clone)] +struct EncoderLayer { + self_attn: Attention, + self_attn_layer_norm: LayerNorm, + activation_fn: candle_nn::Activation, + fc1: Linear, + fc2: Linear, + final_layer_norm: LayerNorm, +} + +impl EncoderLayer { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?; + let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?; + let fc1 = linear(cfg.d_model, cfg.encoder_ffn_dim, vb.pp("fc1"))?; + let fc2 = linear(cfg.encoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?; + let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?; + Ok(Self { + self_attn, + self_attn_layer_norm, + activation_fn: cfg.activation_function, + fc1, + fc2, + final_layer_norm, + }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let residual = xs; + let xs = + (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?; + let residual = &xs; + let xs = xs + .apply(&self.fc1)? + .apply(&self.activation_fn)? 
+ .apply(&self.fc2)?; + (xs + residual)?.apply(&self.final_layer_norm) + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + self_attn_layer_norm: LayerNorm, + activation_fn: candle_nn::Activation, + encoder_attn: Attention, + encoder_attn_layer_norm: LayerNorm, + fc1: Linear, + fc2: Linear, + final_layer_norm: LayerNorm, +} + +impl DecoderLayer { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?; + let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?; + let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?; + let encoder_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?; + let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?; + let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?; + let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?; + Ok(Self { + self_attn, + self_attn_layer_norm, + activation_fn: cfg.activation_function, + encoder_attn, + encoder_attn_layer_norm, + fc1, + fc2, + final_layer_norm, + }) + } + + fn forward(&self, xs: &Tensor, encoder_xs: Option<&Tensor>) -> Result { + let residual = xs; + let xs = + (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?; + let xs = match encoder_xs { + None => xs, + Some(encoder_xs) => { + let residual = &xs; + let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?; + (residual + xs)?.apply(&self.self_attn_layer_norm)? + } + }; + let residual = &xs; + let xs = xs + .apply(&self.fc1)? + .apply(&self.activation_fn)? + .apply(&self.fc2)?; + (xs + residual)?.apply(&self.final_layer_norm) + } +} + +#[derive(Debug, Clone)] +pub struct Encoder { + embed_tokens: Embedding, + embed_positions: SinusoidalPositionalEmbedding, + layers: Vec, + embed_scale: Option, +} + +impl Encoder { + fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result { + let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?; + let mut layers = Vec::with_capacity(cfg.encoder_layers); + let vb_l = vb.pp("layers"); + for idx in 0..cfg.encoder_layers { + let layer = EncoderLayer::new(cfg, vb_l.pp(idx))?; + layers.push(layer) + } + let embed_scale = if cfg.scale_embedding { + Some((cfg.d_model as f64).sqrt()) + } else { + None + }; + Ok(Self { + embed_tokens: embed_tokens.clone(), + embed_positions, + layers, + embed_scale, + }) + } + + pub fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result { + let xs = xs.apply(&self.embed_tokens)?; + let xs = match self.embed_scale { + None => xs, + Some(scale) => (xs * scale)?, + }; + let embed_pos = self + .embed_positions + .forward(&xs, past_kv_len)? + .unsqueeze(0)?; + let mut xs = xs.broadcast_add(&embed_pos)?; + for layer in self.layers.iter() { + xs = layer.forward(&xs)? 
+ } + Ok(xs) + } +} + +#[derive(Debug, Clone)] +pub struct Decoder { + embed_tokens: Embedding, + embed_positions: SinusoidalPositionalEmbedding, + layers: Vec, + embed_scale: Option, +} + +impl Decoder { + fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result { + let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?; + let mut layers = Vec::with_capacity(cfg.decoder_layers); + let vb_l = vb.pp("layers"); + for idx in 0..cfg.decoder_layers { + let layer = DecoderLayer::new(cfg, vb_l.pp(idx))?; + layers.push(layer) + } + let embed_scale = if cfg.scale_embedding { + Some((cfg.d_model as f64).sqrt()) + } else { + None + }; + Ok(Self { + embed_tokens: embed_tokens.clone(), + embed_positions, + layers, + embed_scale, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + encoder_xs: Option<&Tensor>, + past_kv_len: usize, + ) -> Result { + let xs = xs.apply(&self.embed_tokens)?; + let xs = match self.embed_scale { + None => xs, + Some(scale) => (xs * scale)?, + }; + let embed_pos = self + .embed_positions + .forward(&xs, past_kv_len)? + .unsqueeze(0)?; + let mut xs = xs.broadcast_add(&embed_pos)?; + for layer in self.layers.iter() { + xs = layer.forward(&xs, encoder_xs)? + } + Ok(xs) + } +} + +#[derive(Debug, Clone)] +struct Model { + shared: Embedding, + encoder: Encoder, + decoder: Decoder, +} + +impl Model { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let encoder = Encoder::new(cfg, &shared, vb.pp("encoder"))?; + let decoder = Decoder::new(cfg, &shared, vb.pp("decoder"))?; + Ok(Self { + shared, + encoder, + decoder, + }) + } +} + +#[derive(Debug, Clone)] +pub struct MTModel { + model: Model, + final_logits_bias: Tensor, +} + +impl MTModel { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size); + let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?; + let model = Model::new(cfg, vb.pp("model"))?; + Ok(Self { + model, + final_logits_bias, + }) + } + + pub fn encoder(&self) -> &Encoder { + &self.model.encoder + } + + pub fn decoder(&self) -> &Decoder { + &self.model.decoder + } + + pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result { + self.model.decoder.forward(xs, Some(encoder_xs), 0) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index aecfcd672b..370b9108df 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -10,6 +10,7 @@ pub mod jina_bert; pub mod llama; pub mod llama2_c; pub mod llama2_c_weights; +pub mod marian; pub mod mistral; pub mod mixformer; pub mod mpt; diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index 39258085d1..a657011c34 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -14,6 +14,13 @@ impl Embedding { Ok(Self { inner, span }) } + pub fn from_weights(weights: Tensor) -> Result { + let (_in_size, out_size) = weights.dims2()?; + let inner = candle_nn::Embedding::new(weights, out_size); + let span = tracing::span!(tracing::Level::TRACE, "embedding"); + Ok(Self { inner, span }) + } + pub fn embeddings(&self) -> &Tensor { self.inner.embeddings() } From 154c674a798fd5a40d57ff9a8664856d9c41ca56 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 29 Oct 2023 16:28:53 +0100 Subject: [PATCH 19/22] Add 
i64-abs. (#1216) --- candle-core/src/op.rs | 35 ++++++++++++++++++++++++++++++- candle-core/tests/tensor_tests.rs | 8 +++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index b7f99f115a..e1168c2e46 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -536,7 +536,6 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln); unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin); unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos); unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh); -unary_op!(Abs, "abs", v, v.abs()); unary_op!(Neg, "neg", v, -v); unary_op!(Recip, "recip", v, v.recip()); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); @@ -666,6 +665,40 @@ impl UnaryOpT for Erf { } } +impl UnaryOpT for Abs { + const NAME: &'static str = "abs"; + const KERNEL: &'static str = "uabs"; + const V: Self = Abs; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v.abs() + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v.abs() + } + #[inline(always)] + fn f32(v: f32) -> f32 { + v.abs() + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v.abs() + } + #[inline(always)] + fn u8(v: u8) -> u8 { + v + } + #[inline(always)] + fn u32(v: u32) -> u32 { + v + } + #[inline(always)] + fn i64(v: i64) -> i64 { + v.abs() + } +} + impl UnaryOpT for Ceil { const NAME: &'static str = "ceil"; const KERNEL: &'static str = "uceil"; diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index ae1bd0581b..899efcf3a4 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1089,3 +1089,11 @@ fn pad_with_same() -> Result<()> { ); Ok(()) } + +#[test] +fn i64_abs() -> Result<()> { + let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?; + let t = t.abs()?; + assert_eq!(t.to_vec1::()?, [42, 1337]); + Ok(()) +} From 174b20805230abaf91b838598d84ab142f31a975 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Sun, 29 Oct 2023 16:41:44 +0100 Subject: [PATCH 20/22] PyO3: Better shape handling (#1143) * Negative and `*args` shape handling * Rename to `PyShapeWithHole` + validate that only one hole exists * Regenerate stubs --------- Co-authored-by: Laurent Mazare --- candle-examples/Cargo.toml | 2 +- candle-pyo3/Cargo.toml | 4 +- candle-pyo3/py_src/candle/__init__.pyi | 16 +-- .../py_src/candle/functional/__init__.pyi | 2 +- candle-pyo3/py_src/candle/typing/__init__.py | 2 + candle-pyo3/py_src/candle/utils/__init__.pyi | 2 +- candle-pyo3/src/lib.rs | 71 +++++++------ candle-pyo3/src/shape.rs | 99 +++++++++++++++++++ candle-pyo3/stub.py | 2 +- candle-pyo3/tests/native/test_shape.py | 31 ++++++ 10 files changed, 181 insertions(+), 50 deletions(-) create mode 100644 candle-pyo3/src/shape.rs create mode 100644 candle-pyo3/tests/native/test_shape.py diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 7372e24f20..b1913541fb 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -21,7 +21,7 @@ half = { workspace = true, optional = true } image = { workspace = true } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } -pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true } +pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true } rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 488404bf06..b04524040e 100644 --- a/candle-pyo3/Cargo.toml 
+++ b/candle-pyo3/Cargo.toml @@ -19,10 +19,10 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle-nn = { path = "../candle-nn", version = "0.3.0" } half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } -pyo3 = { version = "0.19.0", features = ["extension-module", "abi3-py38"] } +pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] } [build-dependencies] -pyo3-build-config = "0.19" +pyo3-build-config = "0.20" [features] default = [] diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index 437221683b..35b17680b9 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -1,7 +1,7 @@ # Generated content DO NOT EDIT from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike -from candle.typing import _ArrayLike, Device, Scalar, Index +from candle.typing import _ArrayLike, Device, Scalar, Index, Shape class bf16(DType): pass @@ -26,21 +26,21 @@ class i64(DType): pass @staticmethod -def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: +def ones(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: """ Creates a new tensor filled with ones. """ pass @staticmethod -def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor: +def rand(*shape: Shape, device: Optional[Device] = None) -> Tensor: """ Creates a new tensor with random values. """ pass @staticmethod -def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor: +def randn(*shape: Shape, device: Optional[Device] = None) -> Tensor: """ Creates a new tensor with random values from a normal distribution. """ @@ -67,7 +67,7 @@ class u8(DType): pass @staticmethod -def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: +def zeros(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: """ Creates a new tensor filled with zeros. """ @@ -174,7 +174,7 @@ class Tensor: Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. """ pass - def broadcast_as(self, shape: Sequence[int]) -> Tensor: + def broadcast_as(self, *shape: Shape) -> Tensor: """ Broadcasts the tensor to the given shape. """ @@ -184,7 +184,7 @@ class Tensor: Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. """ pass - def broadcast_left(self, shape: Sequence[int]) -> Tensor: + def broadcast_left(self, *shape: Shape) -> Tensor: """ Broadcasts the tensor to the given shape, adding new dimensions on the left. """ @@ -329,7 +329,7 @@ class Tensor: Get the `recip` of the tensor. """ pass - def reshape(self, shape: Sequence[int]) -> Tensor: + def reshape(self, *shape: Shape) -> Tensor: """ Reshapes the tensor to the given shape. 
""" diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi index 5bf5c4c31f..4f7c2aa65a 100644 --- a/candle-pyo3/py_src/candle/functional/__init__.pyi +++ b/candle-pyo3/py_src/candle/functional/__init__.pyi @@ -1,7 +1,7 @@ # Generated content DO NOT EDIT from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike -from candle.typing import _ArrayLike, Device, Scalar, Index +from candle.typing import _ArrayLike, Device, Scalar, Index, Shape from candle import Tensor, DType, QTensor @staticmethod diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py index 66bc3d8aba..b2262a97d0 100644 --- a/candle-pyo3/py_src/candle/typing/__init__.py +++ b/candle-pyo3/py_src/candle/typing/__init__.py @@ -18,3 +18,5 @@ Scalar = Union[int, float] Index = Union[int, slice, None, "Ellipsis"] + +Shape = Union[int, Sequence[int]] diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi index d3b9376675..4ee51c290b 100644 --- a/candle-pyo3/py_src/candle/utils/__init__.pyi +++ b/candle-pyo3/py_src/candle/utils/__init__.pyi @@ -1,7 +1,7 @@ # Generated content DO NOT EDIT from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike -from candle.typing import _ArrayLike, Device, Scalar, Index +from candle.typing import _ArrayLike, Device, Scalar, Index, Shape from candle import Tensor, DType, QTensor @staticmethod diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 6d4de80bfb..41c4577fee 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -16,26 +16,13 @@ extern crate accelerate_src; use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; +mod shape; +use shape::{PyShape, PyShapeWithHole}; + pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::(format!("{err:?}")) } -#[derive(Clone, Debug)] -struct PyShape(Vec); - -impl<'source> pyo3::FromPyObject<'source> for PyShape { - fn extract(ob: &'source PyAny) -> PyResult { - let dims: Vec = pyo3::FromPyObject::extract(ob)?; - Ok(PyShape(dims)) - } -} - -impl From for ::candle::Shape { - fn from(val: PyShape) -> Self { - val.0.into() - } -} - #[derive(Clone, Debug)] #[pyclass(name = "Tensor")] /// A `candle` tensor. @@ -684,25 +671,37 @@ impl PyTensor { Ok(Self(tensor)) } - #[pyo3(text_signature = "(self, shape:Sequence[int])")] + #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")] /// Reshapes the tensor to the given shape. /// &RETURNS&: Tensor - fn reshape(&self, shape: PyShape) -> PyResult { - Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) + fn reshape(&self, shape: PyShapeWithHole) -> PyResult { + Ok(PyTensor( + self.0 + .reshape(shape.to_absolute(&self.0)?) + .map_err(wrap_err)?, + )) } - #[pyo3(text_signature = "(self, shape:Sequence[int])")] + #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")] /// Broadcasts the tensor to the given shape. /// &RETURNS&: Tensor - fn broadcast_as(&self, shape: PyShape) -> PyResult { - Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?)) + fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult { + Ok(PyTensor( + self.0 + .broadcast_as(shape.to_absolute(&self.0)?) 
+                .map_err(wrap_err)?,
+        ))
     }
 
-    #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+    #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
     /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
     /// &RETURNS&: Tensor
-    fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
-        Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
+    fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> {
+        Ok(PyTensor(
+            self.0
+                .broadcast_left(shape.to_absolute(&self.0)?)
+                .map_err(wrap_err)?,
+        ))
     }
 
     #[pyo3(text_signature = "(self, dim:int)")]
@@ -915,21 +914,21 @@ impl PyTensor {
         }
 
         if let Some(kwargs) = kwargs {
-            if let Some(any) = kwargs.get_item("dtype") {
+            if let Ok(Some(any)) = kwargs.get_item("dtype") {
                 handle_duplicates(
                     &mut dtype,
                     any.extract::<PyDType>(),
                     "cannot specify multiple dtypes",
                 )?;
             }
-            if let Some(any) = kwargs.get_item("device") {
+            if let Ok(Some(any)) = kwargs.get_item("device") {
                 handle_duplicates(
                     &mut device,
                     any.extract::<PyDevice>(),
                     "cannot specify multiple devices",
                 )?;
            }
-            if let Some(any) = kwargs.get_item("other") {
+            if let Ok(Some(any)) = kwargs.get_item("other") {
                 handle_duplicates(
                     &mut other,
                     any.extract::<PyTensor>(),
@@ -1049,27 +1048,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
 }
 
 #[pyfunction]
-#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
 /// Creates a new tensor with random values.
 /// &RETURNS&: Tensor
 fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
     let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
-    let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
+    let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }
 
 #[pyfunction]
-#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
 /// Creates a new tensor with random values from a normal distribution.
 /// &RETURNS&: Tensor
 fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
     let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
-    let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
+    let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }
 
 #[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
 /// Creates a new tensor filled with ones.
 /// &RETURNS&: Tensor
 fn ones(
@@ -1083,12 +1082,12 @@ fn ones(
         Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
     };
     let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
-    let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
+    let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }
 
 #[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
 /// Creates a new tensor filled with zeros.
 /// &RETURNS&: Tensor
 fn zeros(
@@ -1102,7 +1101,7 @@ fn zeros(
         Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
     };
     let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
-    let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
+    let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }
 
diff --git a/candle-pyo3/src/shape.rs b/candle-pyo3/src/shape.rs
new file mode 100644
index 0000000000..2668b7331b
--- /dev/null
+++ b/candle-pyo3/src/shape.rs
@@ -0,0 +1,99 @@
+use ::candle::Tensor;
+use pyo3::prelude::*;
+
+#[derive(Clone, Debug)]
+/// Represents an absolute shape e.g. (1, 2, 3)
+pub struct PyShape(Vec<usize>);
+
+impl<'source> pyo3::FromPyObject<'source> for PyShape {
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        if ob.is_none() {
+            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+                "Shape cannot be None",
+            ));
+        }
+
+        let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
+        if tuple.len() == 1 {
+            let first_element = tuple.get_item(0)?;
+            let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
+            Ok(PyShape(dims))
+        } else {
+            let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
+            Ok(PyShape(dims))
+        }
+    }
+}
+
+impl From<PyShape> for ::candle::Shape {
+    fn from(val: PyShape) -> Self {
+        val.0.into()
+    }
+}
+
+#[derive(Clone, Debug)]
+/// Represents a shape with a hole in it e.g. (1, -1, 3)
+pub struct PyShapeWithHole(Vec<isize>);
+
+impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        if ob.is_none() {
+            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+                "Shape cannot be None",
+            ));
+        }
+
+        let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
+        let dims: Vec<isize> = if tuple.len() == 1 {
+            let first_element = tuple.get_item(0)?;
+            pyo3::FromPyObject::extract(first_element)?
+        } else {
+            pyo3::FromPyObject::extract(tuple)?
+        };
+
+        // Ensure we have only positive numbers and at most one "hole" (-1)
+        let negative_ones = dims.iter().filter(|&&x| x == -1).count();
+        let any_invalid_dimensions = dims.iter().any(|&x| x < -1 || x == 0);
+        if negative_ones > 1 || any_invalid_dimensions {
+            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+                "Invalid dimension in shape: {:?}",
+                dims
+            )));
+        }
+
+        Ok(PyShapeWithHole(dims))
+    }
+}
+
+impl PyShapeWithHole {
+    /// Returns `true` if the shape is absolute e.g. (1, 2, 3)
+    pub fn is_absolute(&self) -> bool {
+        self.0.iter().all(|x| *x > 0)
+    }
+
+    /// Convert a relative shape to an absolute shape e.g. (1, -1) -> (1, 12)
+    pub fn to_absolute(&self, t: &Tensor) -> PyResult<PyShape> {
+        if self.is_absolute() {
+            return Ok(PyShape(
+                self.0.iter().map(|x| *x as usize).collect::<Vec<usize>>(),
+            ));
+        }
+
+        let mut elements = t.elem_count();
+        let mut new_dims: Vec<usize> = vec![];
+        for dim in self.0.iter() {
+            if *dim > 0 {
+                new_dims.push(*dim as usize);
+                elements /= *dim as usize;
+            } else if *dim == -1 {
+                new_dims.push(elements);
+            } else {
+                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+                    "Invalid dimension in shape: {}",
+                    dim
+                )));
+            }
+        }
+        Ok(PyShape(new_dims))
+    }
+}
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
index 8e4318bcf5..336f674ba8 100644
--- a/candle-pyo3/stub.py
+++ b/candle-pyo3/stub.py
@@ -13,7 +13,7 @@
 TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
 from os import PathLike
 """
-CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
+CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index, Shape\n"
 CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
 RETURN_TYPE_MARKER = "&RETURNS&: "
 ADDITIONAL_TYPEHINTS = {}
diff --git a/candle-pyo3/tests/native/test_shape.py b/candle-pyo3/tests/native/test_shape.py
new file mode 100644
index 0000000000..864e24d679
--- /dev/null
+++ b/candle-pyo3/tests/native/test_shape.py
@@ -0,0 +1,31 @@
+from candle import Tensor
+from candle import rand
+import pytest
+
+
+def test_absolute_shapes_are_valid():
+    a = rand((10, 20))
+    assert a.shape == (10, 20)
+
+    b = rand(10, 20)
+    assert b.shape == (10, 20)
+    pytest.raises(OverflowError, lambda: rand((10, 20, -1)))
+    pytest.raises(OverflowError, lambda: rand(-1, 20))
+    pytest.raises(TypeError, lambda: rand("foo", True))
+
+
+def test_relative_shapes_are_valid():
+    a = rand(10, 20)
+    a = a.reshape((1, -1))
+    assert a.shape == (1, 200)
+
+    b = rand(10, 20)
+    b = b.reshape(-1, 1)
+    assert b.shape == (200, 1)
+
+    c = rand(10, 20)
+    pytest.raises(TypeError, lambda: c.reshape(1, "foo"))
+    pytest.raises(ValueError, lambda: c.reshape(1, -2))
+    pytest.raises(ValueError, lambda: c.reshape((-2, 1)))
+    pytest.raises(ValueError, lambda: c.reshape((0, 1)))
+    pytest.raises(ValueError, lambda: c.reshape((1, -1, -1)))

From 5fc66bd4baa49e14d26c1955e1e7e9c505bc2544 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 30 Oct 2023 08:40:54 +0100
Subject: [PATCH 21/22] Support negative steps in arange. (#1218)

---
 candle-core/src/tensor.rs         | 16 +++++++++++++---
 candle-core/tests/tensor_tests.rs | 20 ++++++++++++++++++++
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index c6f2364d60..adcdc59d59 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -385,11 +385,21 @@
         step: D,
         device: &Device,
     ) -> Result<Self> {
+        if D::is_zero(&step) {
+            crate::bail!("step cannot be zero")
+        }
         let mut data = vec![];
         let mut current = start;
-        while current < end {
-            data.push(current);
-            current += step;
+        if step >= D::zero() {
+            while current < end {
+                data.push(current);
+                current += step;
+            }
+        } else {
+            while current > end {
+                data.push(current);
+                current += step;
+            }
         }
         let len = data.len();
         Self::from_vec_impl(data, len, device, false)
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 899efcf3a4..734cb7e85f 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -29,7 +29,26 @@ fn ones(device: &Device) -> Result<()> {
         Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
         [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
     );
+    Ok(())
+}
 
+fn arange(device: &Device) -> Result<()> {
+    assert_eq!(
+        Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
+        [0, 1, 2, 3, 4],
+    );
+    assert_eq!(
+        Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
+        [0, 2, 4],
+    );
+    assert_eq!(
+        Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
+        [0, 3],
+    );
+    assert_eq!(
+        Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
+        [5, 4, 3, 2, 1],
+    );
     Ok(())
 }
@@ -1037,6 +1056,7 @@ fn randn(device: &Device) -> Result<()> {
 
 test_device!(zeros, zeros_cpu, zeros_gpu);
 test_device!(ones, ones_cpu, ones_gpu);
+test_device!(arange, arange_cpu, arange_gpu);
 test_device!(add_mul, add_mul_cpu, add_mul_gpu);
 test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
 test_device!(narrow, narrow_cpu, narrow_gpu);

From 969960847ac7fd4959e8718d1355abb1f9f4385d Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 30 Oct 2023 12:44:19 +0100
Subject: [PATCH 22/22] Bugfixes for marian-mt. (#1219)

* Bugfixes for marian-mt.

* Apply the final decoding head.

* More fixes.
---
 candle-examples/examples/marian-mt/main.rs |  7 +++---
 candle-transformers/src/models/marian.rs   | 27 ++++++++++++++--------
 2 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs
index ed044627c6..bf33743c74 100644
--- a/candle-examples/examples/marian-mt/main.rs
+++ b/candle-examples/examples/marian-mt/main.rs
@@ -36,8 +36,6 @@ struct Args {
     text: String,
 }
 
-const SEP_TOKEN_ID: u32 = 102;
-
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
 
@@ -62,7 +60,7 @@ pub fn main() -> anyhow::Result<()> {
         model.encoder().forward(&tokens, 0)?
     };
 
-    let mut token_ids = vec![30522u32];
+    let mut token_ids = vec![config.decoder_start_token_id];
     for index in 0..1000 {
         // TODO: Add a kv cache.
         let context_size = if index >= 1000 { 1 } else { token_ids.len() };
@@ -72,7 +70,8 @@ pub fn main() -> anyhow::Result<()> {
         let logits = logits.squeeze(0)?;
         let logits = logits.get(logits.dim(0)? - 1)?;
         let token = logits_processor.sample(&logits)?;
-        if token == SEP_TOKEN_ID {
+        println!("{token}");
+        if token == config.eos_token_id || token == config.forced_eos_token_id {
             break;
         }
         token_ids.push(token);
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index d48ce38b11..71f177200b 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -18,11 +18,11 @@ pub struct Config {
     pub is_encoder_decoder: bool,
     pub activation_function: candle_nn::Activation,
     pub d_model: usize,
-    pub decoder_start_token_id: usize,
+    pub decoder_start_token_id: u32,
     pub scale_embedding: bool,
-    pub pad_token_id: usize,
-    pub eos_token_id: usize,
-    pub forced_eos_token_id: usize,
+    pub pad_token_id: u32,
+    pub eos_token_id: u32,
+    pub forced_eos_token_id: u32,
     pub share_encoder_decoder_embeddings: bool,
 }
 
@@ -224,7 +224,8 @@ impl DecoderLayer {
         let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
         let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
         let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
-        let encoder_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
+        let encoder_attn_layer_norm =
+            layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
         let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
         let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
         let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
@@ -249,7 +250,7 @@ impl DecoderLayer {
             Some(encoder_xs) => {
                 let residual = &xs;
                 let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
-                (residual + xs)?.apply(&self.self_attn_layer_norm)?
+                (residual + xs)?.apply(&self.encoder_attn_layer_norm)?
             }
         };
         let residual = &xs;
@@ -257,7 +258,8 @@ impl DecoderLayer {
             .apply(&self.fc1)?
             .apply(&self.activation_fn)?
             .apply(&self.fc2)?;
-        (xs + residual)?.apply(&self.final_layer_norm)
+        let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
+        Ok(xs)
     }
 }
 
@@ -356,7 +358,7 @@ impl Decoder {
             .unsqueeze(0)?;
         let mut xs = xs.broadcast_add(&embed_pos)?;
         for layer in self.layers.iter() {
-            xs = layer.forward(&xs, encoder_xs)?
+            xs = layer.forward(&xs, encoder_xs)?;
         }
         Ok(xs)
     }
@@ -385,6 +387,7 @@ impl Model {
 #[derive(Debug, Clone)]
 pub struct MTModel {
     model: Model,
+    lm_head: Linear,
     final_logits_bias: Tensor,
 }
 
@@ -393,8 +396,10 @@ impl MTModel {
         let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
         let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
         let model = Model::new(cfg, vb.pp("model"))?;
+        let lm_head = Linear::from_weights(model.shared.embeddings().clone(), None);
         Ok(Self {
             model,
+            lm_head,
             final_logits_bias,
         })
     }
@@ -408,6 +413,10 @@ impl MTModel {
     }
 
     pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
-        self.model.decoder.forward(xs, Some(encoder_xs), 0)
+        self.model
+            .decoder
+            .forward(xs, Some(encoder_xs), 0)?
+            .apply(&self.lm_head)?
+            .broadcast_add(&self.final_logits_bias)
     }
 }
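
Note on the PyO3 shape-handling patch above: `reshape`, `broadcast_as` and `broadcast_left` now accept a variadic shape in which a single `-1` acts as a hole that `PyShapeWithHole::to_absolute` fills from the tensor's element count. The sketch below is a standalone illustration of that resolution rule in plain Rust with no candle or pyo3 dependency; `resolve_hole` is a hypothetical name, and its divisibility check is an extra safeguard rather than a copy of `to_absolute`.

```rust
// Standalone sketch: resolve a shape that may contain one -1 "hole"
// against a tensor's total element count.
fn resolve_hole(dims: &[i64], elem_count: usize) -> Result<Vec<usize>, String> {
    if dims.iter().filter(|&&d| d == -1).count() > 1 {
        return Err("at most one -1 hole is allowed".to_string());
    }
    if dims.iter().any(|&d| d == 0 || d < -1) {
        return Err(format!("invalid dimension in shape: {dims:?}"));
    }
    // Product of the dimensions that are given explicitly.
    let known: usize = dims
        .iter()
        .filter(|&&d| d > 0)
        .map(|&d| d as usize)
        .product();
    if elem_count % known != 0 {
        return Err(format!("{elem_count} elements do not fit shape {dims:?}"));
    }
    Ok(dims
        .iter()
        .map(|&d| if d == -1 { elem_count / known } else { d as usize })
        .collect())
}

fn main() {
    // A tensor with 2 * 4 * 3 = 24 elements: (2, -1, 3) resolves to (2, 4, 3).
    assert_eq!(resolve_hole(&[2, -1, 3], 24), Ok(vec![2, 4, 3]));
    // A shape without a hole passes through unchanged.
    assert_eq!(resolve_hole(&[4, 6], 24), Ok(vec![4, 6]));
    // More than one hole is rejected, matching the PyShapeWithHole validation.
    assert!(resolve_hole(&[-1, -1, 2], 24).is_err());
}
```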
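The arange change in PATCH 21 picks the loop direction from the sign of the step and rejects a zero step, which would otherwise never terminate. A small standalone sketch of that rule over `f64` follows; it is an illustration only, not the `Tensor::arange_step` implementation, and `arange_values` is a hypothetical helper.

```rust
// Sketch of the stepping rule from the arange patch: walk towards `end`
// in the direction given by the sign of `step`, rejecting a zero step.
fn arange_values(start: f64, end: f64, step: f64) -> Result<Vec<f64>, String> {
    if step == 0.0 {
        return Err("step cannot be zero".to_string());
    }
    let mut data = vec![];
    let mut current = start;
    if step > 0.0 {
        while current < end {
            data.push(current);
            current += step;
        }
    } else {
        while current > end {
            data.push(current);
            current += step;
        }
    }
    Ok(data)
}

fn main() {
    assert_eq!(arange_values(0.0, 3.0, 1.0), Ok(vec![0.0, 1.0, 2.0]));
    // With a negative step the sequence counts down, mirroring the new test:
    // Tensor::arange_step(5i64, 0i64, -1, ...) yields [5, 4, 3, 2, 1].
    assert_eq!(arange_values(5.0, 0.0, -1.0), Ok(vec![5.0, 4.0, 3.0, 2.0, 1.0]));
    assert!(arange_values(0.0, 3.0, 0.0).is_err());
}
```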