Skip to content

Commit

Permalink
Fix cuda bug (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
meenchen authored Aug 23, 2023
1 parent ac0a11c commit 1fed7a3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions transformer/src/nn_modules/cuda/Int4llamaAttention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,10 @@ struct Int4llamaAttention_output Int4llamaAttention::forward(const struct Int4ll
// PROFILE_START(profile_name + "::RotaryPosEmb_cuda_forward");
dim3 grid(num_heads, 1, 1);
dim3 block(sqlen, 1, 1);
// RotaryPosEmb_cuda_forward<<<grid, block>>>(query_states, key_states, this->rotary_pos_emb.cos, this->rotary_pos_emb.sin, start_idx, sqlen);
RotaryPosEmb_cuda_forward<<<grid, block>>>(query_states, key_states, this->rotary_pos_emb.cos, this->rotary_pos_emb.sin, start_idx, sqlen);

const int shared_memory_size = 2 * this->embed_dim * sizeof(half);
RotaryPosEmb_cuda_forward_new<<<grid, block, shared_memory_size>>>(query_states, key_states, this->rotary_pos_emb.cos, this->rotary_pos_emb.sin, start_idx, sqlen);
// const int shared_memory_size = 2 * this->embed_dim * sizeof(half);
// RotaryPosEmb_cuda_forward_new<<<grid, block, shared_memory_size>>>(query_states, key_states, this->rotary_pos_emb.cos, this->rotary_pos_emb.sin, start_idx, sqlen);

// const int threads_per_block = 1024; // This value can be tuned for best performance.
// const int blocks_per_grid = (num_heads * sqlen + threads_per_block - 1) / threads_per_block;
Expand Down

0 comments on commit 1fed7a3

Please sign in to comment.