Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check user-specified model_max_len with hf derived max_model_len #1778

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
limitations under the License.
"""

import logging
import os
from enum import IntEnum, auto
from typing import Optional

from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import get_config, get_context_length

logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
MLA = auto()
Expand All @@ -46,10 +50,29 @@ def __init__(
model_override_args=model_override_args,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
derived_context_len = get_context_length(self.hf_text_config)
allow_long_context = os.environ.get(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
)

if context_length is not None:
self.context_len = context_length
if context_length > derived_context_len:
if allow_long_context:
logger.warning(
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
f"This may lead to incorrect model outputs or CUDA errors."
)
self.context_len = context_length
else:
raise ValueError(
f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
)
else:
self.context_len = context_length
else:
self.context_len = get_context_length(self.hf_text_config)
self.context_len = derived_context_len

# Unify the config keys for hf_text_config
self.head_dim = getattr(
Expand Down
Loading