diff --git a/python/mlc_llm/model/cohere/cohere_model.py b/python/mlc_llm/model/cohere/cohere_model.py
index 180c60ba13..8fd7c349a3 100644
--- a/python/mlc_llm/model/cohere/cohere_model.py
+++ b/python/mlc_llm/model/cohere/cohere_model.py
@@ -4,7 +4,7 @@
 """
 
 import dataclasses
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 from tvm import te, tir
 from tvm.relax.frontend import nn
@@ -32,6 +32,7 @@ class CohereConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
     num_key_value_heads: int
     intermediate_size: int
     layer_norm_eps: float
+    use_qk_norm: bool
     position_embedding_base: int = 0
     context_window_size: int = 0
     prefill_chunk_size: int = 0
@@ -112,7 +113,7 @@ def forward(self, x):
 
 
 # pylint: disable=invalid-name,missing-docstring
-class CohereAttention(nn.Module):
+class CohereAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: CohereConfig):
         self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards
         assert config.num_attention_heads % config.tensor_parallel_shards == 0, (
@@ -124,7 +125,17 @@ def __init__(self, config: CohereConfig):
             f"num_attention_heads({config.num_key_value_heads}) "
             "must be divisible by tensor_parallel_shards"
         )
+        self.head_dim = config.head_dim
+        self.use_qk_norm = config.use_qk_norm
+
+        if self.use_qk_norm:
+            self.q_norm = CohereNorm(
+                hidden_size=[self.num_q_heads, self.head_dim], eps=config.layer_norm_eps
+            )
+            self.k_norm = CohereNorm(
+                hidden_size=[self.num_key_value_heads, self.head_dim], eps=config.layer_norm_eps
+            )
 
         self.qkv_proj = nn.Linear(
             in_features=config.hidden_size,
@@ -139,6 +150,13 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
         # QKV Projection
         qkv = self.qkv_proj(hidden_states)
         qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))
+
+        if self.use_qk_norm:
+            q, k, v = op.split(qkv, indices_or_sections=[h_q, h_q + h_kv], axis=2)
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+            qkv = op.concat([q, k, v], dim=2)
+
         # Attention
         output = op.reshape(
             paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads),
@@ -188,17 +206,25 @@ def _apply_parallel_residual(self, mlp_out, residual):
 
 class CohereNorm(nn.Module):
     def __init__(
-        self, normalized_shape: int, eps: float = 1e-5, dtype: Optional[str] = None
+        self,
+        hidden_size: Optional[Union[int, list]] = None,
+        eps: float = 1e-5,
+        dtype: Optional[str] = None,
     ) -> None:
-        super().__init__()
-        self.normalized_shape = normalized_shape
+        self.hidden_size = hidden_size
         self.eps = eps
-        self.weight = nn.Parameter((normalized_shape,), dtype=dtype)
+        if isinstance(hidden_size, int):
+            normalized_shape = [hidden_size]
+        elif isinstance(hidden_size, list):
+            normalized_shape = hidden_size
+        else:
+            raise ValueError("hidden_size must be an int or a list of ints")
+        self.weight = nn.Parameter(normalized_shape, dtype=dtype)
 
     def forward(self, x: Tensor) -> Tensor:
         return op.layer_norm(
             x,
-            normalized_shape=self.normalized_shape,
+            normalized_shape=self.hidden_size,
             weight=self.weight,
             bias=None,
             eps=self.eps,
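
For context, here is a minimal, framework-free sketch of the QK-norm step this patch introduces: after the fused QKV projection, Q and K are split out, layer-normalized jointly over the (heads, head_dim) axes with a learned scale and no bias, and concatenated back before attention. The NumPy helper, shapes, and weight values below are illustrative assumptions for the example only; the real model uses the TVM relax frontend ops shown in the diff.

```python
import numpy as np


def layer_norm(x, weight, eps=1e-5):
    # Normalize jointly over the trailing axes covered by `weight`
    # (here the head and head-dim axes), scale by `weight`, no bias --
    # mirroring CohereNorm with normalized_shape=[num_heads, head_dim].
    axes = tuple(range(x.ndim - weight.ndim, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight


b, s = 2, 5              # batch size, sequence length (arbitrary example values)
h_q, h_kv, d = 8, 2, 64  # query heads, key/value heads, head dimension

# Fused QKV activations, shaped like the reshape in CohereAttention.forward
qkv = np.random.randn(b, s, h_q + 2 * h_kv, d).astype("float32")

# Learned per-head norm scales (unit-initialized here for illustration)
q_weight = np.ones((h_q, d), dtype="float32")
k_weight = np.ones((h_kv, d), dtype="float32")

# Split -> normalize Q and K -> concatenate back, as in the diff
q, k, v = np.split(qkv, [h_q, h_q + h_kv], axis=2)
q = layer_norm(q, q_weight)
k = layer_norm(k, k_weight)
qkv = np.concatenate([q, k, v], axis=2)

print(qkv.shape)  # (2, 5, 12, 64): same fused layout, with Q and K normalized
```

The fused layout is preserved, so the normalized tensor can still be handed to the fused-QKV attention kernel unchanged, which is why the patch splits and re-concatenates rather than reworking the attention call.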