diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 982e5ee11..b7efb93b3 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -9,7 +9,7 @@ use half::{bf16, f16}; use metal; use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication}; use metal::mps::{Float32, MPSDataType}; -use metal::MTLResourceOptions; +use metal::{MTLResourceOptions, Buffer}; /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -47,6 +47,14 @@ impl MetalDevice { pub fn id(&self) -> u64 { self.registry_id() } + + fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer{ + let size = (element_count * dtype.size_in_bytes()) as u64; + self.device.new_buffer( + size, + MTLResourceOptions::empty(), + ) + } } #[derive(Debug, Clone)] @@ -106,11 +114,16 @@ impl BackendStorage for MetalStorage { todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype) } - fn unary_impl(&self, _: &Layout) -> Result { - // todo!() - // TODO - println!("TODO {:?}", B::NAME); - Ok(self.clone()) + fn unary_impl(&self, layout: &Layout) -> Result { + let device = self.device().clone(); + let dtype = self.dtype; + let shape = layout.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let mut buffer = device.new_buffer(el_count, dtype); + todo!("Implement the kernel calling"); + // device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype); + Ok(Self { buffer, device, dtype }) } fn binary_impl(&self, _: &Self, _: &Layout, _: &Layout) -> Result { @@ -271,13 +284,10 @@ impl MetalStorage { let elem_count = b * m * n; match (self.dtype, rhs.dtype) { (DType::F32, DType::F32) => { + let out_buffer = self.device.new_buffer(elem_count, self.dtype); if b != 1 { println!("TODO implement batched matmul for B={b}"); // bail!("Didn't implemented strided matmul yet"); - let out_buffer = self.device.new_buffer( - (elem_count * mem::size_of::()) as u64, - MTLResourceOptions::empty(), - ); return Ok(Self { buffer: out_buffer, device: self.device.clone(), @@ -286,20 +296,12 @@ impl MetalStorage { } if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous()); - let out_buffer = self.device.new_buffer( - (elem_count * mem::size_of::()) as u64, - MTLResourceOptions::empty(), - ); return Ok(Self { buffer: out_buffer, device: self.device.clone(), dtype: self.dtype(), }); } - let out_buffer = self.device.new_buffer( - (elem_count * mem::size_of::()) as u64, - MTLResourceOptions::empty(), - ); let m: u64 = m.try_into().expect("usize should fit u64"); let n: u64 = n.try_into().expect("usize should fit u64"); let k: u64 = k.try_into().expect("usize should fit u64"); @@ -359,6 +361,7 @@ impl MetalStorage { } } + impl BackendDevice for MetalDevice { type Storage = MetalStorage; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 8b1378917..fdc14b957 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1 +1,138 @@ +use metal::{Buffer, Device, Function, Library, CompileOptions}; +use std::collections::HashMap; +use std::sync::RwLock; +static UNARY: &'static str = include_str!("unary.metal"); + +pub enum Error {} + +pub struct Kernels { + libraries: RwLock>, + funcs: RwLock>, +} + +impl Kernels { + pub fn new() -> Self { + let libraries = RwLock::new(HashMap::new()); + let funcs = RwLock::new(HashMap::new()); + Self { libraries, funcs } + } + pub fn call_unary( + &self, + device: &Device, + name: &str, + input: &Buffer, + output: &mut Buffer, + length: usize, + ) -> Result<(), Error> { + if let Some(func) = self + .funcs + .read() + .expect("Failed to acquire kernel lock") + .get(name) + { + call_unary(func, input, output, length); + } else { + let func = self + .libraries + .write() + .expect("Failed to acquire lock") + .entry("unary") + .or_insert_with(|| { + device + .new_library_with_source(UNARY, &CompileOptions::new()) + .expect("Failed to load unary library") + }) + .get_function(name, None) + .expect("Could not find unary function"); + self.funcs + .write() + .expect("Failed to acquire lock") + .insert(name.to_string(), func.clone()); + call_unary(&func, input, output, length); + } + Ok(()) + } +} + +fn call_unary(func: &Function, input: &Buffer, output: &Buffer, length: usize) { + todo!("Call unary"); +} + +#[cfg(test)] +mod tests { + use super::*; + use metal::{ + ComputePipelineDescriptor, MTLResourceOptions, MTLResourceUsage, MTLSize, + }; + + fn approx(v: Vec, digits: i32) -> Vec{ + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t * b) / b).collect() + } + + #[test] + fn cos() { + let v = vec![1.0f32, 2.0, 3.0]; + let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache; + let device = Device::system_default().unwrap(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let encoder = command_buffer.new_compute_command_encoder(); + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + (v.len() * core::mem::size_of::()) as u64, + option, + ); + let output = device.new_buffer((v.len() * core::mem::size_of::()) as u64, option); + let library = device + .new_library_with_source(UNARY, &CompileOptions::new()) + .expect("Failed to load unary library"); + let func = library.get_function("cos", None).unwrap(); + let argument_encoder = func.new_argument_encoder(0); + let arg_buffer = device.new_buffer( + argument_encoder.encoded_length(), + MTLResourceOptions::empty(), + ); + argument_encoder.set_argument_buffer(&arg_buffer, 0); + argument_encoder.set_buffer(0, &input, 0); + argument_encoder.set_buffer(1, &output, 0); + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline_state = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + encoder.set_compute_pipeline_state(&pipeline_state); + encoder.set_buffer(0, Some(&arg_buffer), 0); + + encoder.use_resource(&input, MTLResourceUsage::Read); + encoder.use_resource(&output, MTLResourceUsage::Write); + + let width = 16; + + let thread_group_count = MTLSize { + width, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: (v.len() as u64 + width) / width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + let results = output.read_to_vec::(v.len()); + assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]); + assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); + } +} diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal new file mode 100644 index 000000000..3861b2f05 --- /dev/null +++ b/candle-metal-kernels/src/unary.metal @@ -0,0 +1,14 @@ +#include + +using namespace metal; + +struct Input { + device float *input; + device float *output; +}; + +kernel void cos(device Input& args [[ buffer(0) ]], uint index [[thread_position_in_grid]]) +{ + args.output[index] = cos(args.input[index]); +} +