From 62757db6f0f09a6dff15b1ee1ac3029602951509 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Fri, 9 Aug 2024 16:36:57 -0700 Subject: [PATCH] Reduce the overhead when cache is disabled (#1010) --- .../sglang/srt/managers/policy_scheduler.py | 45 +++++++++---------- python/sglang/srt/managers/schedule_batch.py | 5 +++ python/sglang/srt/managers/tp_worker.py | 22 ++------- python/sglang/srt/mem_cache/radix_cache.py | 6 +++ 4 files changed, 35 insertions(+), 43 deletions(-) diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py index 30a009c2e6..a05ba9c9c3 100644 --- a/python/sglang/srt/managers/policy_scheduler.py +++ b/python/sglang/srt/managers/policy_scheduler.py @@ -18,44 +18,40 @@ import random from collections import defaultdict from contextlib import contextmanager +from typing import List from sglang.srt.managers.schedule_batch import Req, ScheduleBatch class PolicyScheduler: - def __init__( - self, - policy, - max_running_seqs, - max_prefill_num_tokens, - max_total_num_tokens, - tree_cache, - ): - if tree_cache.disable and policy == "lpm": - # LMP is meaningless when the tree cache is disabled. + def __init__(self, policy, tree_cache): + if tree_cache.disable and policy in ["lpm", "dfs-weight"]: + # LPM and DFS-weight is meaningless when the tree cache is disabled. policy = "fcfs" self.policy = policy - self.max_running_seqs = max_running_seqs - self.max_prefill_num_tokens = max_prefill_num_tokens - self.max_total_num_tokens = max_total_num_tokens self.tree_cache = tree_cache - def get_priority_queue(self, waiting_queue): + def calc_priority(self, waiting_queue: List[Req]): + if self.policy in ["lpm", "dfs-weight"]: + # Compute matched prefix length + for r in waiting_queue: + # NOTE: the prefix_indices must always be aligned with last_node + r.prefix_indices, r.last_node = self.tree_cache.match_prefix( + rid=r.rid, key=r.adjust_max_prefix_ids() + ) + if self.policy == "lpm": - # longest prefix match + # Longest Prefix Match waiting_queue.sort(key=lambda x: -len(x.prefix_indices)) - return waiting_queue elif self.policy == "fcfs": # first come first serve - return waiting_queue + pass elif self.policy == "lof": # longest output first waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) - return waiting_queue elif self.policy == "random": random.shuffle(waiting_queue) - return waiting_queue elif self.policy == "dfs-weight": last_node_to_reqs = defaultdict(list) for req in waiting_queue: @@ -66,12 +62,13 @@ def get_priority_queue(self, waiting_queue): node_to_weight[node] = len(last_node_to_reqs[node]) self.calc_weight(self.tree_cache.root_node, node_to_weight) - q = [] + waiting_queue.clear() self.get_dfs_priority( - self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q + self.tree_cache.root_node, + node_to_weight, + last_node_to_reqs, + waiting_queue, ) - assert len(q) == len(waiting_queue) - return q else: raise ValueError(f"Unknown schedule_policy: {self.policy}") @@ -139,8 +136,6 @@ def _prefill_one_req( self.log_input_tokens += extend_input_len def add_inflight_req(self, req: Req): - req.input_ids = req.origin_input_ids + req.output_ids - req.extend_input_len = len(req.input_ids) - len(req.prefix_indices) truncated = req.extend_input_len > self.rem_chunk_tokens req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len] diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2489abd5de..278ed006ef 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -164,7 +164,12 @@ def __init__(self, rid, origin_input_text, origin_input_ids): def finished(self) -> bool: return self.finished_reason is not None + def init_next_round_input(self): + self.input_ids = self.origin_input_ids + self.output_ids + self.extend_input_len = len(self.input_ids) - len(self.prefix_indices) + def adjust_max_prefix_ids(self): + self.input_ids = self.origin_input_ids + self.output_ids input_len = len(self.input_ids) max_prefix_len = input_len diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 0228073c77..c668977106 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -165,13 +165,7 @@ def __init__( disable=server_args.disable_radix_cache, ) self.tree_cache_metrics = {"total": 0, "hit": 0} - self.scheduler = PolicyScheduler( - self.schedule_policy, - self.max_running_requests, - self.max_prefill_tokens, - self.max_total_num_tokens, - self.tree_cache, - ) + self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache) self.req_to_token_pool = self.model_runner.req_to_token_pool self.token_to_kv_pool = self.model_runner.token_to_kv_pool @@ -373,17 +367,8 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: if running_bs >= self.max_running_requests: return None - # Compute matched prefix length - for req in self.waiting_queue: - req.input_ids = req.origin_input_ids + req.output_ids - # NOTE: the prefix_indices must always be aligned with last_node - req.prefix_indices, req.last_node = self.tree_cache.match_prefix( - rid=req.rid, key=req.adjust_max_prefix_ids() - ) - req.extend_input_len = len(req.input_ids) - len(req.prefix_indices) - # Get priority queue - self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue) + self.scheduler.calc_priority(self.waiting_queue) adder = PrefillAdder( self.tree_cache, @@ -397,12 +382,13 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: has_inflight = self.current_inflight_req is not None if self.current_inflight_req is not None: + self.current_inflight_req.init_next_round_input() self.current_inflight_req = adder.add_inflight_req( self.current_inflight_req ) for req in self.waiting_queue: - + req.init_next_round_input() res = adder.add_one_req(req) if ( not res diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index c238120492..05cbb2c926 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -169,6 +169,9 @@ def evict(self, num_tokens, evict_callback): heapq.heappush(leaves, x.parent) def inc_lock_ref(self, node: TreeNode): + if self.disable: + return 0 + delta = 0 while node != self.root_node: if node.lock_ref == 0: @@ -179,6 +182,9 @@ def inc_lock_ref(self, node: TreeNode): return delta def dec_lock_ref(self, node: TreeNode): + if self.disable: + return 0 + delta = 0 while node != self.root_node: if node.lock_ref == 1: