diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 8e4884b28d..f59fce4850 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -34,7 +34,10 @@ fn ceil_div(p: usize, q: usize) -> usize { } fn pad(p: usize, q: usize) -> usize { - ceil_div(p, q) * q + // Overallocate by q rather than just padding by q as this should pad the last row + // and we don't have enough information here to know how many elements to add :( + // ceil_div(p, q) * q + p + q } fn quantize_q8_1( @@ -439,7 +442,7 @@ impl QCudaStorage { } _ => crate::bail!("only f32 can be quantized"), }; - let src_len = src.len(); + let src_len = pad(src.len(), MATRIX_ROW_PADDING); let src = crate::Storage::Cpu(crate::CpuStorage::F32(src)); let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; qcpu_storage.quantize(&src)?; diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs index fa6eff51d3..1af21fd3a0 100644 --- a/candle-core/src/quantized/utils.rs +++ b/candle-core/src/quantized/utils.rs @@ -18,7 +18,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>( let actual_blocks = ys.len(); // Validate that the input is the right size - if expected_blocks != actual_blocks { + if actual_blocks < expected_blocks { crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!") }