[Model] Add use_qk_norm option for Cohere model #2877

Open · wants to merge 4 commits into base: main
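For context on what the new option does: query-key normalization applies a LayerNorm to the query and key heads between the fused QKV projection and the attention call. Below is a minimal NumPy sketch of that idea under assumed shapes; it is illustrative only, not the code in this PR (which uses the relax nn API and the paged KV cache), and it normalizes each head over head_dim alone, whereas the axes actually normalized depend on the shape handed to CohereNorm.

import numpy as np

def per_head_norm(x, weight, eps=1e-5):
    # LayerNorm over the last axis (head_dim), no bias; weight broadcasts per head.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight

# Hypothetical sizes: batch=1, seq=3, 4 query heads, 2 KV heads, head_dim=8.
b, s, h_q, h_kv, d = 1, 3, 4, 2, 8
qkv = np.random.randn(b, s, h_q + 2 * h_kv, d).astype("float32")

q, k, v = np.split(qkv, [h_q, h_q + h_kv], axis=2)
q = per_head_norm(q, np.ones((h_q, d), dtype="float32"))    # q_norm
k = per_head_norm(k, np.ones((h_kv, d), dtype="float32"))   # k_norm
qkv = np.concatenate([q, k, v], axis=2)                      # re-fused for attention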
40 changes: 33 additions & 7 deletions python/mlc_llm/model/cohere/cohere_model.py
@@ -4,7 +4,7 @@
 """
 
 import dataclasses
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 from tvm import te, tir
 from tvm.relax.frontend import nn
@@ -32,6 +32,7 @@ class CohereConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
     num_key_value_heads: int
     intermediate_size: int
     layer_norm_eps: float
+    use_qk_norm: bool
     position_embedding_base: int = 0
     context_window_size: int = 0
     prefill_chunk_size: int = 0
@@ -112,7 +113,7 @@ def forward(self, x):
 # pylint: disable=invalid-name,missing-docstring
 
 
-class CohereAttention(nn.Module):
+class CohereAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: CohereConfig):
         self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards
         assert config.num_attention_heads % config.tensor_parallel_shards == 0, (
@@ -124,7 +125,17 @@ def __init__(self, config: CohereConfig):
             f"num_attention_heads({config.num_key_value_heads}) "
             "must be divisible by tensor_parallel_shards"
         )
 
         self.head_dim = config.head_dim
+        self.use_qk_norm = config.use_qk_norm
+
+        if self.use_qk_norm:
+            self.q_norm = CohereNorm(
+                hidden_size=[self.num_q_heads, self.head_dim], eps=config.layer_norm_eps
+            )
+            self.k_norm = CohereNorm(
+                hidden_size=[self.num_key_value_heads, self.head_dim], eps=config.layer_norm_eps
+            )
+
         self.qkv_proj = nn.Linear(
             in_features=config.hidden_size,
@@ -139,6 +150,13 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
         # QKV Projection
         qkv = self.qkv_proj(hidden_states)
         qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))
+
+        if self.use_qk_norm:
+            q, k, v = op.split(qkv, indices_or_sections=[h_q, h_q + h_kv], axis=2)
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+            qkv = op.concat([q, k, v], dim=2)
+
         # Attention
         output = op.reshape(
             paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads),
@@ -188,17 +206,25 @@ def _apply_parallel_residual(self, mlp_out, residual):
 
 class CohereNorm(nn.Module):
     def __init__(
-        self, normalized_shape: int, eps: float = 1e-5, dtype: Optional[str] = None
+        self,
+        hidden_size: Optional[Union[int, list]] = None,
+        eps: float = 1e-5,
+        dtype: Optional[str] = None,
     ) -> None:
         super().__init__()
-        self.normalized_shape = normalized_shape
+        self.hidden_size = hidden_size
         self.eps = eps
-        self.weight = nn.Parameter((normalized_shape,), dtype=dtype)
+        if isinstance(hidden_size, int):
+            normalized_shape = [hidden_size]
+        elif isinstance(hidden_size, list):
+            normalized_shape = hidden_size
+        else:
+            raise ValueError("hidden_size must be an int or a list of ints")
+        self.weight = nn.Parameter(normalized_shape, dtype=dtype)
 
     def forward(self, x: Tensor) -> Tensor:
         return op.layer_norm(
             x,
-            normalized_shape=self.normalized_shape,
+            normalized_shape=self.hidden_size,
             weight=self.weight,
             bias=None,
             eps=self.eps,
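A usage note on the CohereNorm change: the constructor now takes either an int (the full hidden size, the previous behavior) or a list such as [num_heads, head_dim], which is what lets the same module serve as the per-head q_norm and k_norm. Below is a small standalone sketch of that shape handling, mirroring the constructor logic above but independent of TVM; the helper name is made up for illustration.

from typing import Optional, Union

def resolve_normalized_shape(hidden_size: Optional[Union[int, list]]) -> list:
    # Mirrors CohereNorm.__init__: an int becomes a one-element shape,
    # a list is used as-is; anything else is rejected.
    if isinstance(hidden_size, int):
        return [hidden_size]
    if isinstance(hidden_size, list):
        return hidden_size
    raise ValueError("hidden_size must be an int or a list of ints")

print(resolve_normalized_shape(8192))       # [8192]     -> regular layer-norm weight shape
print(resolve_normalized_shape([64, 128]))  # [64, 128]  -> per-head q/k norm weight shape

The resolved list determines the shape of the norm's weight parameter; the forward pass passes hidden_size through unchanged as normalized_shape in the op.layer_norm call.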