Commit 62757db
Reduce the overhead when cache is disabled (#1010)
hnyls2002 authored Aug 9, 2024
1 parent 73fa2d4 commit 62757db
Showing 4 changed files with 35 additions and 43 deletions.
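In short: when the radix cache is disabled (server_args.disable_radix_cache), the scheduling policy falls back to "fcfs", prefix matching against the radix tree is skipped entirely, per-round request state is rebuilt by a new Req.init_next_round_input() helper instead of being recomputed at every call site, and the lock-reference walks in the radix cache return immediately.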
45 changes: 20 additions & 25 deletions python/sglang/srt/managers/policy_scheduler.py
@@ -18,44 +18,40 @@
 import random
 from collections import defaultdict
 from contextlib import contextmanager
+from typing import List
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 
 
 class PolicyScheduler:
-    def __init__(
-        self,
-        policy,
-        max_running_seqs,
-        max_prefill_num_tokens,
-        max_total_num_tokens,
-        tree_cache,
-    ):
-        if tree_cache.disable and policy == "lpm":
-            # LMP is meaningless when the tree cache is disabled.
+    def __init__(self, policy, tree_cache):
+        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
+            # LPM and DFS-weight is meaningless when the tree cache is disabled.
             policy = "fcfs"
 
         self.policy = policy
-        self.max_running_seqs = max_running_seqs
-        self.max_prefill_num_tokens = max_prefill_num_tokens
-        self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, waiting_queue):
+    def calc_priority(self, waiting_queue: List[Req]):
+        if self.policy in ["lpm", "dfs-weight"]:
+            # Compute matched prefix length
+            for r in waiting_queue:
+                # NOTE: the prefix_indices must always be aligned with last_node
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                )
+
         if self.policy == "lpm":
-            # longest prefix match
+            # Longest Prefix Match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return waiting_queue
         elif self.policy == "fcfs":
             # first come first serve
-            return waiting_queue
+            pass
         elif self.policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return waiting_queue
         elif self.policy == "random":
             random.shuffle(waiting_queue)
-            return waiting_queue
         elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
@@ -66,12 +62,13 @@ def get_priority_queue(self, waiting_queue):
                 node_to_weight[node] = len(last_node_to_reqs[node])
             self.calc_weight(self.tree_cache.root_node, node_to_weight)
 
-            q = []
+            waiting_queue.clear()
             self.get_dfs_priority(
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
+                self.tree_cache.root_node,
+                node_to_weight,
+                last_node_to_reqs,
+                waiting_queue,
             )
-            assert len(q) == len(waiting_queue)
-            return q
         else:
             raise ValueError(f"Unknown schedule_policy: {self.policy}")

@@ -139,8 +136,6 @@ def _prefill_one_req(
         self.log_input_tokens += extend_input_len
 
     def add_inflight_req(self, req: Req):
-        req.input_ids = req.origin_input_ids + req.output_ids
-        req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]
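The net effect of the hunks above: get_priority_queue, which returned a (possibly new) list, becomes calc_priority, which reorders the waiting queue in place and consults the radix cache only for the policies that need prefix information. A minimal runnable sketch of that in-place contract, covering the simple policies only; StubReq and StubScheduler are illustrative stand-ins, not the real sglang classes:

import random
from dataclasses import dataclass, field
from typing import List


@dataclass
class StubReq:
    rid: int
    prefix_indices: List[int] = field(default_factory=list)


class StubScheduler:
    def __init__(self, policy: str):
        self.policy = policy

    def calc_priority(self, waiting_queue: List[StubReq]) -> None:
        # Mirrors the new contract: mutate the queue, return nothing.
        if self.policy == "lpm":
            # Longest Prefix Match: longest cached prefix runs first
            waiting_queue.sort(key=lambda r: -len(r.prefix_indices))
        elif self.policy == "fcfs":
            pass  # first come first serve: keep arrival order
        elif self.policy == "random":
            random.shuffle(waiting_queue)


queue = [StubReq(0, [1]), StubReq(1, [1, 2, 3]), StubReq(2, [])]
StubScheduler("lpm").calc_priority(queue)
assert [r.rid for r in queue] == [1, 0, 2]  # sorted by prefix length, descending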
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -164,7 +164,12 @@ def __init__(self, rid, origin_input_text, origin_input_ids):
     def finished(self) -> bool:
         return self.finished_reason is not None
 
+    def init_next_round_input(self):
+        self.input_ids = self.origin_input_ids + self.output_ids
+        self.extend_input_len = len(self.input_ids) - len(self.prefix_indices)
+
     def adjust_max_prefix_ids(self):
         self.input_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.input_ids)
         max_prefix_len = input_len
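The new init_next_round_input helper centralizes the two lines that the scheduler and the TP worker previously duplicated at each call site. A small sketch of its observable behavior, using the field names from the diff; ReqSketch is a hypothetical reduction of the real Req class:

from typing import List


class ReqSketch:
    """Reduced Req: only the fields init_next_round_input touches."""

    def __init__(self, origin_input_ids: List[int]):
        self.origin_input_ids = origin_input_ids
        self.output_ids: List[int] = []
        self.prefix_indices: List[int] = []

    def init_next_round_input(self):
        self.input_ids = self.origin_input_ids + self.output_ids
        self.extend_input_len = len(self.input_ids) - len(self.prefix_indices)


req = ReqSketch([10, 11, 12])
req.output_ids = [13]
req.prefix_indices = [0, 1]  # two tokens already cached
req.init_next_round_input()
assert req.input_ids == [10, 11, 12, 13]
assert req.extend_input_len == 2  # only the uncached tail is extended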
22 changes: 4 additions & 18 deletions python/sglang/srt/managers/tp_worker.py
@@ -165,13 +165,7 @@ def __init__(
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = PolicyScheduler(
-            self.schedule_policy,
-            self.max_running_requests,
-            self.max_prefill_tokens,
-            self.max_total_num_tokens,
-            self.tree_cache,
-        )
+        self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
         self.req_to_token_pool = self.model_runner.req_to_token_pool
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool

@@ -373,17 +367,8 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         if running_bs >= self.max_running_requests:
             return None
 
-        # Compute matched prefix length
-        for req in self.waiting_queue:
-            req.input_ids = req.origin_input_ids + req.output_ids
-            # NOTE: the prefix_indices must always be aligned with last_node
-            req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
-                rid=req.rid, key=req.adjust_max_prefix_ids()
-            )
-            req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
-
         # Get priority queue
-        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
+        self.scheduler.calc_priority(self.waiting_queue)
 
         adder = PrefillAdder(
             self.tree_cache,
@@ -397,12 +382,13 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
 
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
+            self.current_inflight_req.init_next_round_input()
             self.current_inflight_req = adder.add_inflight_req(
                 self.current_inflight_req
             )
 
         for req in self.waiting_queue:
-
+            req.init_next_round_input()
             res = adder.add_one_req(req)
             if (
                 not res
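With the per-queue prefix-match pass removed from the worker, the prefill loop now refreshes each request's per-round state immediately before admission. A hedged, self-contained sketch of that ordering; Req and BudgetAdder below are illustrative stand-ins (the real PrefillAdder also tracks chunked prefill and KV-cache budgets):

from typing import List


class Req:
    """Hypothetical stand-in: just enough of sglang's Req for this flow."""

    def __init__(self, origin_input_ids: List[int]):
        self.origin_input_ids = origin_input_ids
        self.output_ids: List[int] = []
        self.prefix_indices: List[int] = []

    def init_next_round_input(self):
        # Same helper as sketched above for schedule_batch.py.
        self.input_ids = self.origin_input_ids + self.output_ids
        self.extend_input_len = len(self.input_ids) - len(self.prefix_indices)


class BudgetAdder:
    """Hypothetical PrefillAdder: admits requests until a token budget runs out."""

    def __init__(self, budget: int):
        self.budget = budget
        self.batch: List[Req] = []

    def add_one_req(self, req: Req) -> bool:
        if req.extend_input_len > self.budget:
            return False
        self.budget -= req.extend_input_len
        self.batch.append(req)
        return True


def build_prefill_batch(adder: BudgetAdder, waiting_queue: List[Req]) -> List[Req]:
    # Key ordering after this commit: per-round fields are refreshed right
    # before each request is handed to the adder, not in a separate pass.
    for req in waiting_queue:
        req.init_next_round_input()
        if not adder.add_one_req(req):
            break
    return adder.batch


queue = [Req([1, 2, 3]), Req([4, 5]), Req([6, 7, 8, 9])]
batch = build_prefill_batch(BudgetAdder(budget=6), queue)
assert len(batch) == 2  # third req (4 tokens) exceeds the remaining budget of 1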
6 changes: 6 additions & 0 deletions python/sglang/srt/mem_cache/radix_cache.py
@@ -169,6 +169,9 @@ def evict(self, num_tokens, evict_callback):
             heapq.heappush(leaves, x.parent)
 
     def inc_lock_ref(self, node: TreeNode):
+        if self.disable:
+            return 0
+
         delta = 0
         while node != self.root_node:
             if node.lock_ref == 0:
@@ -179,6 +182,9 @@ def inc_lock_ref(self, node: TreeNode):
         return delta
 
     def dec_lock_ref(self, node: TreeNode):
+        if self.disable:
+            return 0
+
         delta = 0
         while node != self.root_node:
             if node.lock_ref == 1:
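When the radix cache is disabled there is no tree to protect, so both lock-reference methods now return before walking any parent pointers. A sketch of the guard, assuming a simplified TreeNode; the real inc_lock_ref also updates the cache's evictable-size accounting, which this sketch reduces to a counter:

from typing import Optional


class TreeNode:
    def __init__(self, parent: Optional["TreeNode"] = None):
        self.parent = parent
        self.lock_ref = 0


class RadixCacheSketch:
    def __init__(self, disable: bool):
        self.disable = disable
        self.root_node = TreeNode()

    def inc_lock_ref(self, node: TreeNode) -> int:
        if self.disable:
            return 0  # nothing is cached, so there is nothing to lock
        delta = 0
        while node != self.root_node:
            if node.lock_ref == 0:
                delta += 1  # simplified; the real method adjusts evictable size
            node.lock_ref += 1
            node = node.parent
        return delta


cache = RadixCacheSketch(disable=True)
leaf = TreeNode(parent=cache.root_node)
assert cache.inc_lock_ref(leaf) == 0  # O(1) early return: no parent walk
assert leaf.lock_ref == 0             # tree state untouched when disabled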
