
Commit

Merge branch 'main' of github.com:sgl-project/sglang
Chayenne committed Oct 26, 2024
2 parents e232132 + 2b80978 commit c56edf2
Showing 5 changed files with 26 additions and 11 deletions.
.github/workflows/deploy-docs.yml (1 addition, 1 deletion)

@@ -46,7 +46,7 @@ jobs:
   build-and-deploy:
     needs: execute-notebooks
     if: github.repository == 'sgl-project/sglang'
-    runs-on: ubuntu-latest
+    runs-on: 1-gpu-runner
     steps:
       - name: Checkout code
         uses: actions/checkout@v3
python/sglang/srt/managers/schedule_policy.py (11 additions, 6 deletions)

@@ -30,7 +30,9 @@
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
+CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
+)
 
 
 class SchedulePolicy:
@@ -146,7 +148,7 @@ def __init__(
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
                 )
                 * self.new_token_ratio
                 for r in running_batch.reqs
@@ -186,7 +188,7 @@ def add_inflight_req(self, req: Req):
             len(req.prefix_indices),
             req.extend_input_len,
             (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                 if not truncated
                 else 0
             ),
@@ -258,7 +260,7 @@ def add_req_state(r, insert_sort=False):
             self._prefill_one_req(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
             # Chunked prefill
@@ -276,7 +278,7 @@ def add_one_req(self, req: Req):
             return self.add_one_req_ignore_eos(req)
 
         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
         input_tokens = req.extend_input_len
         prefix_len = len(req.prefix_indices)
@@ -302,7 +304,10 @@ def add_one_req(self, req: Req):
             self._prefill_one_req(
                 prefix_len,
                 input_tokens,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(
+                    req.sampling_params.max_new_tokens,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                ),
             )
         else:
             # Chunked prefill
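To make the renamed constant's role concrete: the clip only affects the scheduler's estimate of how many more tokens a running request might still produce; the request itself stops at its own unclipped max_new_tokens. Below is a minimal, self-contained sketch of that estimation, assuming a simplified Req with just the fields the hunks touch; it is illustrative, not the repository's actual class.

import os
from dataclasses import dataclass, field
from typing import List

# Mirrors the renamed constant: clips only the scheduler's estimate,
# not the actual stop condition of the request.
CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)


@dataclass
class Req:
    max_new_tokens: int  # user-requested limit (unclipped)
    output_ids: List[int] = field(default_factory=list)


def estimated_remaining_tokens(reqs: List[Req], new_token_ratio: float) -> float:
    """Token budget the scheduler reserves for already-running requests."""
    return sum(
        min(r.max_new_tokens - len(r.output_ids), CLIP_MAX_NEW_TOKENS_ESTIMATION)
        * new_token_ratio
        for r in reqs
    )


# A request asking for 100k new tokens only reserves 4096 * ratio in the estimate,
# but it is still allowed to generate past 4096 tokens at runtime.
print(estimated_remaining_tokens([Req(max_new_tokens=100_000)], new_token_ratio=0.5))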
python/sglang/srt/model_executor/cuda_graph_runner.py (6 additions, 3 deletions)

@@ -113,12 +113,15 @@ def __init__(self, model_runner: "ModelRunner"):
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
 
         # Batch sizes to capture
-        if self.model_runner.server_args.disable_cuda_graph_padding:
+        if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
-            self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         self.capture_bs = [
-            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+            bs
+            for bs in self.capture_bs
+            if bs <= model_runner.req_to_token_pool.size
+            and bs <= model_runner.server_args.max_cuda_graph_bs
         ]
         self.compile_bs = (
             [
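The reshaped comprehension adds a second cap: a CUDA graph is only captured for a batch size that fits both the request-to-token pool and the new max_cuda_graph_bs limit (default 160, the top of the 8-step ladder). A standalone sketch of the selection, with made-up pool and cap values for illustration:

from typing import List


def select_capture_batch_sizes(
    disable_cuda_graph_padding: bool,
    req_pool_size: int,
    max_cuda_graph_bs: int,
) -> List[int]:
    """Reproduces the capture_bs selection from the hunk above."""
    if disable_cuda_graph_padding:
        capture_bs = list(range(1, 32)) + [64, 128]
    else:
        capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
    return [
        bs
        for bs in capture_bs
        if bs <= req_pool_size and bs <= max_cuda_graph_bs
    ]


# With the defaults (ladder up to 160, cap 160) nothing is filtered out;
# lowering the cap to 64 trims the ladder to [1, 2, 4, 8, ..., 64].
print(select_capture_batch_sizes(False, req_pool_size=2048, max_cuda_graph_bs=64))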
python/sglang/srt/server_args.py (7 additions, 0 deletions)

@@ -120,6 +120,7 @@ class ServerArgs:
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
+    max_cuda_graph_bs: int = 160
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
@@ -624,6 +625,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
         default=ServerArgs.max_torch_compile_bs,
         help="Set the maximum batch size when using torch compile.",
     )
+    parser.add_argument(
+        "--max-cuda-graph-bs",
+        type=int,
+        default=ServerArgs.max_cuda_graph_bs,
+        help="Set the maximum batch size for cuda graph.",
+    )
     parser.add_argument(
         "--torchao-config",
         type=str,
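The new field mirrors max_torch_compile_bs: a dataclass default of 160 that the CLI flag overrides. A condensed, hypothetical sketch of just that wiring, keeping only the relevant pieces of ServerArgs:

import argparse
from dataclasses import dataclass


@dataclass
class ServerArgs:
    max_torch_compile_bs: int = 32
    max_cuda_graph_bs: int = 160  # new: upper bound for CUDA graph capture


def add_cli_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--max-cuda-graph-bs",
        type=int,
        default=ServerArgs.max_cuda_graph_bs,
        help="Set the maximum batch size for cuda graph.",
    )


parser = argparse.ArgumentParser()
add_cli_args(parser)
args = parser.parse_args(["--max-cuda-graph-bs", "64"])
print(args.max_cuda_graph_bs)  # 64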
test/srt/test_large_max_new_tokens.py (1 addition, 1 deletion)

@@ -34,7 +34,7 @@ def setUpClass(cls):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
             other_args=("--max-total-token", "1024", "--context-len", "8192"),
-            env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
+            env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
             return_stdout_stderr=(cls.stdout, cls.stderr),
         )
         cls.base_url += "/v1"
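Because the constant is read from the environment at import time, the test sets the variable in the environment of the server process it launches. A minimal sketch of the same idea using plain subprocess instead of the test suite's launch helper; the model path and port are placeholders, and the sglang.launch_server entry point is assumed:

import os
import subprocess

# Tighten the scheduler's per-request estimate to 256 tokens; this does NOT
# change how many tokens a request is actually allowed to generate.
env = {"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ}

subprocess.Popen(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "<your-model>",  # placeholder
        "--port", "30000",               # placeholder
    ],
    env=env,
)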
