Skip unnecessary penalizer (#1707)

merrymercy authored Oct 19, 2024
1 parent bc12d40 commit 2bcfba1
Showing 7 changed files with 104 additions and 75 deletions.
25 changes: 15 additions & 10 deletions python/sglang/srt/managers/schedule_batch.py
@@ -515,11 +515,11 @@ def prepare_for_extend(self, vocab_size: int):
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
+                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
+                    req.prefix_indices
+                )

-            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )
@@ -535,10 +535,15 @@ def prepare_for_extend(self, vocab_size: int):
             pt += req.extend_input_len

         # Set fields
-        with out_cache_loc.device:
-            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
-            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )

         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
@@ -782,8 +787,8 @@ def filter_batch(
             return

         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(
-            keep_indices, dtype=torch.int32, device=self.seq_lens.device
-        )
+        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
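The schedule_batch.py changes above replace the `with device:` construction blocks with an explicit build-on-CPU-then-copy idiom. A minimal standalone sketch of that idiom (function and variable names here are illustrative, not from the diff):

import torch

def build_then_move(values, dtype, device):
    # torch.tensor(...) allocates on the host; .to(..., non_blocking=True)
    # queues the host-to-device copy without stalling the Python thread.
    return torch.tensor(values, dtype=dtype).to(device, non_blocking=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = build_then_move([3, 5, 2], torch.int32, device)
print(seq_lens.device, seq_lens.tolist())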
7 changes: 3 additions & 4 deletions python/sglang/srt/managers/scheduler.py
@@ -150,6 +150,7 @@ def __init__(
             nccl_port=port_args.nccl_port,
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+        self.device = self.tp_worker.device

         # Get token and memory info from the model worker
         (
@@ -758,9 +759,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
             if logits_output.next_token_logprobs is not None:
                 logits_output.next_token_logprobs = (
                     logits_output.next_token_logprobs[
-                        torch.arange(
-                            len(next_token_ids), device=next_token_ids.device
-                        ),
+                        torch.arange(len(next_token_ids), device=self.device),
                         next_token_ids,
                     ].tolist()
                 )
Expand Down Expand Up @@ -828,7 +827,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
# Move logprobs to cpu
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].tolist()

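Both logprob paths in scheduler.py now build the row-index tensor from the scheduler's cached `self.device` instead of re-reading `next_token_ids.device`. A self-contained sketch of the gather being performed (the batch and vocab sizes are made-up toy values):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logprobs = torch.randn(4, 100, device=device)  # [batch_size, vocab_size]
next_token_ids = torch.randint(0, 100, (4,), device=device)

# Row i contributes logprobs[i, next_token_ids[i]]: one value per request.
chosen = logprobs[torch.arange(4, device=device), next_token_ids]
print(chosen.tolist())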
2 changes: 1 addition & 1 deletion python/sglang/srt/mem_cache/memory_pool.py
@@ -90,7 +90,7 @@ def alloc(self, need_size: int):
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]

-        return select_index.to(self.device)
+        return select_index.to(self.device, non_blocking=True)

     def free(self, free_index: torch.Tensor):
         if self.is_not_in_free_group:
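A caveat the diff does not state: `non_blocking=True` only overlaps a host-to-device copy with other host work when the source tensor lives in pinned memory; from ordinary pageable memory the copy is effectively synchronous. A sketch of pinning the source explicitly (illustrative only; the real pool's `free_slots` is not necessarily pinned):

import torch

if torch.cuda.is_available():
    free_slots = torch.arange(1024, dtype=torch.int32).pin_memory()
    # A slice of a pinned tensor is a view of pinned storage, so this copy
    # can genuinely proceed asynchronously on the CUDA stream.
    select_index = free_slots[:16].to("cuda", non_blocking=True)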
29 changes: 13 additions & 16 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -135,25 +135,22 @@ def init_new(

         # Init position information
         if not ret.forward_mode.is_decode():
-            ret.positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(prefix_len, prefix_len + extend_len)
-                        for prefix_len, extend_len in zip(
-                            batch.extend_prefix_lens, batch.extend_seq_lens
-                        )
-                    ],
-                    axis=0,
-                ),
-                dtype=torch.int64,
-                device=device,
+            ret.positions = torch.concat(
+                [
+                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
+                    for prefix_len, extend_len in zip(
+                        batch.extend_prefix_lens, batch.extend_seq_lens
+                    )
+                ],
+                axis=0,
             )

             ret.image_inputs = batch.image_inputs
-            ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
+            ret.extend_seq_lens = torch.tensor(
+                batch.extend_seq_lens, dtype=torch.int32
+            ).to(device, non_blocking=True)
             ret.extend_prefix_lens = torch.tensor(
-                batch.extend_prefix_lens, device=device
-            )
+                batch.extend_prefix_lens, dtype=torch.int32
+            ).to(device, non_blocking=True)
             ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
             ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
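A standalone sketch of the new positions computation, with toy lengths standing in for `batch.extend_prefix_lens` and `batch.extend_seq_lens`: each request contributes an arange over its extend range, concatenated directly on the target device instead of being staged through NumPy on the host.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
extend_prefix_lens = [0, 4]  # toy values
extend_seq_lens = [3, 2]

positions = torch.concat(
    [
        torch.arange(p, p + e, device=device)
        for p, e in zip(extend_prefix_lens, extend_seq_lens)
    ]
)
print(positions.tolist())  # [0, 1, 2, 4, 5]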
37 changes: 28 additions & 9 deletions python/sglang/srt/sampling/penaltylib/orchestrator.py
@@ -37,12 +37,16 @@ def __init__(

         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}

+        is_required = False
         for penalizer in self.penalizers.values():
-            penalizer.prepare_if_required()
+            pen_is_required = penalizer.prepare_if_required()
+            is_required |= pen_is_required
+        self.is_required = is_required

-        self.cumulate_input_tokens(
-            input_ids=[req.origin_input_ids for req in self.reqs()]
-        )
+        if self.is_required:
+            self.cumulate_input_tokens(
+                input_ids=[req.origin_input_ids for req in self.reqs()]
+            )

     def reqs(self):
         return self.batch.reqs
@@ -79,6 +83,9 @@ def cumulate_output_tokens(
         Args:
             output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
         """
+        if not self.is_required:
+            return
+
         token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)

         for penalizer in self.penalizers.values():
@@ -95,6 +102,9 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: The logits after applying the penalizers.
         """
+        if not self.is_required:
+            return
+
         for penalizer in self.penalizers.values():
             logits = penalizer.apply(logits)

@@ -112,10 +122,16 @@ def filter(
             indices_to_keep (typing.List[int]): List of indices to keep in the batch.
             indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
         """
+        if not self.is_required:
+            return
+
         empty_indices = len(indices_to_keep) == 0

+        is_required = False
         for penalizer in self.penalizers.values():
-            if not penalizer.is_required() or empty_indices:
+            tmp_is_required = penalizer.is_required()
+            is_required = is_required or tmp_is_required
+            if not tmp_is_required or empty_indices:
                 penalizer.teardown()
             else:
                 # create tensor index only when it's needed
@@ -128,6 +144,7 @@ def merge(self, their: "BatchedPenalizerOrchestrator"):
                     indices_to_keep=indices_to_keep,
                     indices_tensor_to_keep=indices_tensor_to_keep,
                 )
+        self.is_required = is_required

     def merge(self, their: "BatchedPenalizerOrchestrator"):
         """
@@ -140,11 +157,10 @@ def merge(self, their: "BatchedPenalizerOrchestrator"):
         Args:
             their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
         """
-        if self.vocab_size != their.vocab_size:
-            raise ValueError(
-                f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
-            )
+        if not self.is_required and not their.is_required:
+            return

+        self.is_required |= their.is_required
         for Penalizer, their_penalizer in their.penalizers.items():
             if Penalizer not in self.penalizers:
                 raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
@@ -250,6 +266,9 @@ def prepare(self):
     def prepare_if_required(self):
         if self.is_required():
             self.prepare()
+            return True
+        else:
+            return False

     def teardown(self):
         if self.is_prepared():
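All of the orchestrator changes hang off a single batch-level flag: `is_required` is computed once in `__init__`, and each hot-path method becomes a no-op when it is false. A condensed sketch of the pattern (a toy class, not the real API):

class TinyOrchestrator:
    def __init__(self, penalizers):
        self.penalizers = penalizers
        # prepare_if_required() now reports whether the penalizer prepared state.
        is_required = False
        for p in self.penalizers:
            is_required |= p.prepare_if_required()
        self.is_required = is_required

    def apply(self, logits):
        # Returning None tells the caller the logits were left untouched.
        if not self.is_required:
            return None
        for p in self.penalizers:
            logits = p.apply(logits)
        return logits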
51 changes: 28 additions & 23 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -48,20 +48,24 @@ def from_schedule_batch(
         disable_penalizer: bool,
     ):
         reqs = batch.reqs
-        with batch.input_ids.device:
-            temperatures = torch.tensor(
+        device = batch.input_ids.device
+        temperatures = (
+            torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
-            ).view(-1, 1)
-            top_ps = torch.tensor(
-                [r.sampling_params.top_p for r in reqs], dtype=torch.float
-            )
-            top_ks = torch.tensor(
-                [r.sampling_params.top_k for r in reqs], dtype=torch.int32
-            )
-            min_ps = torch.tensor(
-                [r.sampling_params.min_p for r in reqs], dtype=torch.float
-            )
+            )
+            .view(-1, 1)
+            .to(device, non_blocking=True)
+        )
+        top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float
+        ).to(device, non_blocking=True)
+        top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
+        ).to(device, non_blocking=True)
+        min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float
+        ).to(device, non_blocking=True)

         ret = cls(
             temperatures=temperatures,
@@ -80,7 +84,7 @@ def from_schedule_batch(
         #
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {merge()} cases as well.
+        # handle {filter_batch()} and {merge_batch()} cases as well.
         if disable_penalizer:
             ret.penalizer_orchestrator = None
         else:
@@ -112,19 +116,20 @@ def update_penalties(self):
         self.linear_penalties = None

         for penalizer in self.penalizer_orchestrator.penalizers.values():
+            if not penalizer.is_prepared():
+                continue
+
             if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
-                if penalizer.is_prepared():
-                    self.scaling_penalties = penalizer.cumulated_repetition_penalties
+                self.scaling_penalties = penalizer.cumulated_repetition_penalties
             else:
-                if penalizer.is_prepared():
-                    if self.linear_penalties is None:
-                        bs = self.penalizer_orchestrator.batch.batch_size()
-                        self.linear_penalties = torch.zeros(
-                            (bs, self.vocab_size),
-                            dtype=torch.float32,
-                            device=self.device,
-                        )
-                    self.linear_penalties = penalizer.apply(self.linear_penalties)
+                if self.linear_penalties is None:
+                    bs = self.penalizer_orchestrator.batch.batch_size()
+                    self.linear_penalties = torch.zeros(
+                        (bs, self.vocab_size),
+                        dtype=torch.float32,
+                        device=self.device,
+                    )
+                self.linear_penalties = penalizer.apply(self.linear_penalties)

     def update_regex_vocab_mask(self):
         has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
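For reference, a self-contained sketch of the parameter packing that `from_schedule_batch` performs, using a hypothetical two-field `SamplingParams` in place of sglang's real request objects:

import torch
from dataclasses import dataclass

@dataclass
class SamplingParams:  # hypothetical stand-in
    temperature: float
    top_p: float

reqs = [SamplingParams(0.7, 0.9), SamplingParams(1.0, 0.95)]
device = "cuda" if torch.cuda.is_available() else "cpu"

# .view(-1, 1) keeps temperatures broadcastable against [batch, vocab] logits.
temperatures = (
    torch.tensor([r.temperature for r in reqs], dtype=torch.float)
    .view(-1, 1)
    .to(device, non_blocking=True)
)
top_ps = torch.tensor([r.top_p for r in reqs], dtype=torch.float).to(
    device, non_blocking=True
)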
28 changes: 16 additions & 12 deletions python/sglang/test/srt/sampling/penaltylib/utils.py
@@ -164,19 +164,20 @@ def test_prepare(self):
                     msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                 )

-            actual = orchestrator.apply(
-                torch.ones(
-                    size=(len(case.test_subjects), self.vocab_size),
-                    dtype=torch.float32,
-                    device=self.device,
-                )
+            original = torch.ones(
+                size=(len(case.test_subjects), self.vocab_size),
+                dtype=torch.float32,
+                device=self.device,
             )
+            actual = orchestrator.apply(original.clone())
             expected = torch.cat(
                 tensors=[
                     subject.steps[0].expected_logits
                     for subject in case.test_subjects
                 ],
             )
+            if actual is None:
+                actual = original
             torch.testing.assert_close(
                 actual=actual,
                 expected=expected,
@@ -226,6 +227,8 @@ def test_filter(self):
                     device=self.device,
                 )
             )
+            if actual_logits is None:
+                continue
             filtered_expected_logits = torch.cat(
                 tensors=[
                     subject.steps[0].expected_logits
@@ -317,19 +320,20 @@ def test_cumulate_apply_repeat(self):
                         msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                     )

-                actual_logits = orchestrator.apply(
-                    torch.ones(
-                        size=(len(filtered_subjects), self.vocab_size),
-                        dtype=torch.float32,
-                        device=self.device,
-                    )
+                original = torch.ones(
+                    size=(len(filtered_subjects), self.vocab_size),
+                    dtype=torch.float32,
+                    device=self.device,
                 )
+                actual_logits = orchestrator.apply(original.clone())
                 filtered_expected_logits = torch.cat(
                     tensors=[
                         subject.steps[i].expected_logits
                         for subject in filtered_subjects
                     ],
                 )
+                if actual_logits is None:
+                    actual_logits = original
                 torch.testing.assert_close(
                     actual=actual_logits,
                     expected=filtered_expected_logits,
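The test changes all handle the same new behavior: `orchestrator.apply()` may now return None when no penalizer is active, so the assertions fall back to the unmodified input. In miniature:

import torch

original = torch.ones(2, 8)
actual = None  # what apply() yields when is_required is False
if actual is None:
    actual = original
torch.testing.assert_close(actual=actual, expected=original)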
