Skip to content

Commit

Permalink
Going overbounds will break other kernels running from other threads.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Nov 6, 2023
1 parent 4d87305 commit cd68c96
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
16 changes: 7 additions & 9 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,15 @@ mod tests {

fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
let device = device();
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
let option = metal::MTLResourceOptions::StorageModeManaged;
let options = MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
option,
options,
);
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, option);
let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
let library = device
.new_library_with_source(UNARY, &CompileOptions::new())
.expect("Failed to load unary library");
Expand All @@ -184,8 +183,6 @@ mod tests {
encoder.set_compute_pipeline_state(&pipeline);

encoder.set_bytes(0, 4, void_ptr(&dim));
// encoder.set_bytes(1, 4, void_ptr(&num_dims));
// encoder.set_bytes(2, 4, void_ptr(&info));

encoder.set_buffer(1, Some(&input), 0);
encoder.set_buffer(2, Some(&output), 0);
Expand Down Expand Up @@ -239,8 +236,7 @@ mod tests {
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
// let options = MTLResourceOptions::StorageModeShared;
let options = metal::MTLResourceOptions::StorageModeManaged;
let options = MTLResourceOptions::StorageModeManaged;

let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
Expand Down Expand Up @@ -284,6 +280,7 @@ mod tests {

let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1];
let result = outputs_buffer.read_to_vec::<f32>(output.len());
println!("Result {:?}", result.as_ptr());
assert_eq!(result, expected);
}

Expand All @@ -306,7 +303,7 @@ mod tests {
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let options = metal::MTLResourceOptions::StorageModeManaged;
let options = MTLResourceOptions::StorageModeManaged;

let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
Expand Down Expand Up @@ -353,6 +350,7 @@ mod tests {
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs_buffer.read_to_vec::<f32>(right.len());
println!("Result {:?}", result.as_ptr());
assert_eq!(result, expected);
}

Expand Down
8 changes: 5 additions & 3 deletions candle-metal-kernels/src/unary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ kernel void FN_NAME( \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
output[i] = TYPENAME(FN(input[i])); \
if (i > dim){ \
return; \
} \
output[i] = FN(input[i]); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
Expand All @@ -61,8 +64,7 @@ kernel void FN_NAME_STRIDED( \
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
for (size_t i = start; i < stop; i++) { \
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
output[i] = 1; \
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
}

Expand Down

0 comments on commit cd68c96

Please sign in to comment.