diff --git a/.github/workflows/check_homepage_build.yaml b/.github/workflows/check_homepage_build.yaml index fe2d8bec..6f3609b6 100644 --- a/.github/workflows/check_homepage_build.yaml +++ b/.github/workflows/check_homepage_build.yaml @@ -33,4 +33,4 @@ jobs: - name: Install other homepage dependencies run: pip install -r docs/requirements.txt - name: Build homepage - run: mkdocs build --verbose --strict + run: mkdocs build --verbose diff --git a/docs/perseus/index.md b/docs/perseus/index.md new file mode 100644 index 00000000..7c79b7d8 --- /dev/null +++ b/docs/perseus/index.md @@ -0,0 +1,89 @@ +# Perseus: Energy Scheduling in Large Model Training + +!!! Warning + Perseus is under active development, and breaking changes may happen. + Currently, we have all the low-level APIs in place, but it's not a turnkey solution yet. + This document always reflects the master `HEAD`. + +## Overview + +Perseus finds the training time--energy Pareto frontier of large model training. +Users can pick any point on the frontier -- be it minimum time, minimum energy, or something in the middle, depending on the training deadline. + +Large model training requires the distribution of work to multiple GPUs. +The core observation of Perseus is that work cannot be perfectly split and balanced across every GPU; some GPUs have more work to do and some less. +GPUs with smaller amounts of work finish before GPUs with more amounts of work, but ultimately training throughput is bound by GPUs with the most amount of work. +In other words, GPUs with lighter load are running unnecessarily fast and wasting energy (i.e., there is **energy bloat**). + +We reduce enregy bloat by controlling the execution speed of each pipeline instruction (forward and backward) in each stage by controlling the GPU's frequency in a fine-grained manner. +We call the assignment of a GPU frequency to each pipeline instruction *frequency plan*, and Perseus gives you **every Pareto-optimal frequency plan** that you can choose any point on the iteration time--energy Pareto frontier. +These plans include frequency plans that do not make training any slower compared to not using Perseus at all, but yield free energy savings. +If you have a bit more leeway as to when training should finish (e.g., You're good as long as training finishes by tomorrow morning), you can pick the frequency plan that slows down training by a couple percentages and save more energy. +Our core algorithm, implemented as a separate library called [`lowtime`](https://github.com/ml-energy/lowtime), **provably guarantees** that for any time deadline, energy consumption is minimal. + +## How it's done + +Currently it's a three-step process: + +1. **Profile**: Profile the computation time and energy consumption of the forward and backward instructions in *each stage* and *each GPU frequency*. +2. **Optimize**: Use [`lowtime`](https://github.com/ml-energy/lowtime) to generate all Pareto-optimal frequency plans. +3. **Choose and start training**: Among all the frequency plans generated by `lowtime`, choose the one that suits your use case. + +We have a reference integration with the large model training framework [Merak](https://github.com/ml-energy/merak-zeus), which supports 3D parallelism and automatically tracing and partitioning Hugging Face models. +We've smoothed out some rough edges, integrated Zeus and Perseus, and maintained example scripts for GPT3, BERT, and Wide-ResNet (pretty much any `torchvision` model). + +You don't have to be tied to Merak. +If you have your own training framework, and you can integrate Perseus following [our integration guide](integrating.md). + +### Profile + +In order to run our optimization algorithm, we need the time & energy profiling information of the forward and backward instruction in each stage for every GPU frequency. +The CSV file should look like this for a 4-stage pipeline: + +```csv +stage,instruction,frequency,time,energy +0,forward,1740,0.09373254776000976,28.4944 +0,forward,1725,0.09390360514322917,28.434366666666666 +0,forward,1710,0.09381131331125896,28.288966666666667 +... +0,backward,1740,0.24533510557810465,69.5691 +0,backward,1725,0.24538559118906658,69.2552 +0,backward,1710,0.24548352559407552,68.89453333333334 +... +3,backward,690,0.4184921979904175,68.12243333333333 +3,backward,675,0.42459266185760497,68.77603333333334 +3,backward,660,0.4306272824605306,69.39623333333334 +``` + +Since different frameworks and model implementations will have different performance, it's best to obtain these profiling results on the framework and model you'll be using. +That being said, you can obtain this profiling information in however way you want as long as they have all the columns in the reference CSV file above. +But as a reference, we have implemented an automatic profiler in Merak. +Please refer to the [examples](https://github.com/ml-energy/merak-zeus/tree/main/examples) directory in Merak for profiling instructions. + +!!! Tip + As you profile the time and energy consumption of an instruction, you will scan down from the highest to the lowest frequency. + However, as you lower the GPU's frequency, both time and energy will start to inflate after some point. + In other words, those frequencies take more time **and** energy and are simply inefficient (i.e., Pareto-suboptimal), so we won't be running anything with those frequencies. + Therefore, you actually don't need to profile time and energy for *every* frequency. + A good heuristic is to scan from higher frequencies to lower ones, and once energy consumption increases more than five *consecutive* times, just stop there. + +### Optimize + +With the CSV file that holds profiling results, you can use `lowtime` to generate all Pareto-optimal frequency plans. + +See [`examples/perseus`](https://github.com/ml-energy/zeus/tree/master/examples/perseus) to find the script `run_optimization.py`. + +### Choose and start training + +Running `lowtime` optimization will produce a set of frequency assignment files (`freqs_pipeline_%05d.py`). +Each file is also annotated with estimates for time and cost. +The larger the number, the shorter the expected iteration time. + +Then, start the Perseus server and plug in the frequency plan you chose: + +```console +$ docker exec -it merak-zeus bash +# PERSEUS_SCHEDULER_ARGS='{"solution_path": "path/to/freqs_pipeline_%05d.py"}' uvicorn zeus.optimizer.perseus.server.router:app --port 7787 +``` + +When you run training (with the same `run.sh` but without `--profile true`), the [`PerseusOptimizer`][zeus.optimizer.perseus.optimizer.PerseusOptimizer] integrated into your training framework will automatically talk with the Perseus server to figure out the right GPU frequency to set for the upcoming pipeline instruction and transparently set the GPU's frequency. diff --git a/docs/perseus/integrating.md b/docs/perseus/integrating.md new file mode 100644 index 00000000..0187ad6a --- /dev/null +++ b/docs/perseus/integrating.md @@ -0,0 +1,4 @@ +# Integrating `PerseusOptimizer` + +This page is currently under construction. +We have a reference integration with [Merak](https://github.com/ml-energy/merak-zeus). diff --git a/examples/perseus/run_optimization.py b/examples/perseus/run_optimization.py new file mode 100644 index 00000000..242d3f99 --- /dev/null +++ b/examples/perseus/run_optimization.py @@ -0,0 +1,245 @@ +# Copyright (C) 2023 Jae-Won Chung +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example script of running Perseus energy scheduling.""" + +from __future__ import annotations + +import time +import itertools +import logging +from pathlib import Path +from typing import Type +from collections import defaultdict +from dataclasses import dataclass + +import tyro +import pandas as pd +import networkx as nx +import matplotlib.pyplot as plt + +from lowtime.operation import ( + CandidateExecutionOptions, + OperationSpec, + ExecutionOption, +) +from lowtime.cost_model import ExponentialModel +from lowtime.perseus.instruction import ( + Instruction, + Forward, + Backward, + forward_dep, + backward_dep, +) +from lowtime.solver import PhillipsDessouky +from lowtime.graph_utils import add_source_node, add_sink_node, DependencyResolver +from lowtime.perseus.schedule import Synchronous1F1B +from lowtime.perseus.visualizer import PipelineVisualizer, ANNOTATE_ARGS, LINE_ARGS + +logger = logging.getLogger() + + +@dataclass +class Args: + # Path to instruction profile results + inst_profile: str + # GPU power consumption while blocking on P2P communication, in Watts + p2p_power: float = 70.0 + # Directory to output results + output_dir: Path + # Number of microbatchs + num_mbs: int + # Number of stages + num_stages: int + # Interval to draw the state of the pipeline + plot_interval: int = 100 + # The unit of reduction for each iteration, in seconds + unit_time: float = 0.001 + + +def main(args: Args) -> None: + """Perseus time-cost tradeoff optimization.""" + # Setup logging and output. + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + log_path = output_dir / "job.log" + + logging.basicConfig( + format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", + level=logging.INFO, + handlers=[logging.FileHandler(log_path, mode="a"), logging.StreamHandler()], + ) + logger.info("Arguments: %s", args) + + # Instruction offline profiling results. + inst_df = pd.read_csv(args.inst_profile) + + #################### + # Execution Option # + #################### + # Construct the OperationSpec object of each pipeline instruction in each stage. + op_spec_map: dict[int, dict[Type[Instruction], OperationSpec]] = defaultdict(dict) + for instruction in [Forward, Backward]: + inst_name = instruction.__name__ + for stage_id in range(args.num_stages): + logger.info("Processing %s stage %d", inst_name, stage_id) + options = [] + _df = inst_df.query( + f"stage == {stage_id} and instruction == '{inst_name.lower()}'" + ) + for _, row in _df.iterrows(): + row = row.to_dict() + options.append( + ExecutionOption[int]( + real_time=row["time"], + unit_time=args.unit_time, + cost=row["energy"], + knob=int(row["frequency"]), + ) + ) + + # Get the Preto frontier, quantize time, and deduplicate time. + cand_options = CandidateExecutionOptions[int](options=options) + + # Map the cost to be effective computation energy. + # Everything from now on is in terms of effective energy. + for option in cand_options.options: + option.cost -= args.p2p_power * option.quant_time * option.unit_time + + # Fit the cost model. + model = ExponentialModel(cand_options) + + # Draw the cost model. + fig, ax = plt.subplots(figsize=(8, 8), tight_layout=True) + model.draw(ax, cand_options) + fig.savefig(f"{output_dir}/{inst_name.lower()}_{stage_id}.png") + + # Initialize the operation spec. + op_spec = OperationSpec[int](options=cand_options, cost_model=model) + op_spec_map[stage_id][instruction] = op_spec + + #################### + # DAG construction # + #################### + dag = nx.DiGraph() + + # Generate and add all instructions to the DAG. + # Reserve 0 for dummy source and 1 for dummy sink. + node_id = 2 + instructions: list[list[Instruction]] = [] + for stage_id in range(args.num_stages): + # Generate instructions for each stage. + stage_insts: list[Instruction] = [] + stage_node_ids: list[int] = [] + for inst in Synchronous1F1B( + num_stages=args.num_stages, + num_micro_batches=args.num_mbs, + stage_id=stage_id, + operation_spec_map=op_spec_map[stage_id], + ): + dag.add_node(node_id, op=inst) + stage_insts.append(inst) + stage_node_ids.append(node_id) + node_id += 1 + instructions.append(stage_insts) + + # Add dependencies between adjacent instructions in the same stage. + for node_id1, node_id2 in zip(stage_node_ids, stage_node_ids[1:]): + dag.add_edge(node_id1, node_id2) + + # Add dependencies between dependent pipeline instructions. + insts = dag.nodes(data=True) + resolver = DependencyResolver( + dependency_rules=[forward_dep, backward_dep], + node_type=Instruction, + ) + for (id1, data1), (id2, data2) in itertools.product(insts, insts): + if resolver.is_dependent(data1["op"], data2["op"]): + dag.add_edge(id1, id2) + + # Add source and sink nodes. + add_source_node(dag, 0) + add_sink_node(dag, 1) + dag.graph["source_node"] = 0 + dag.graph["sink_node"] = 1 + + ################################### + # Time-cost tradeoff optimization # + ################################### + def annotation_hook(inst: Instruction) -> str: + return f"{type(inst).__name__[0]}\n{inst.micro_batch_id}" + + def draw(dag: nx.DiGraph, iteration: int, xlim: int) -> None: + ANNOTATE_ARGS[Forward]["fontsize"] = 11.0 + ANNOTATE_ARGS[Backward]["fontsize"] = 11.0 + ANNOTATE_ARGS[Forward]["color"] = "black" + ANNOTATE_ARGS[Backward]["color"] = "black" + LINE_ARGS["linewidth"] = 3.0 + + fig, ax = plt.subplots(figsize=(args.num_mbs, 4), tight_layout=True) + + vis = PipelineVisualizer(dag) + vis.draw( + ax, + draw_time_axis=True, + p2p_power=args.p2p_power, + annotation_hook=annotation_hook, + power_color="RdBu_r", + normalizer_range=(-200, 550), + ) + vis.draw_critical_path(ax) + + # Fix xlim so that we can visually see the pipeline width shrink. + ax.set_xlim(0.0, xlim) + ax.set_title(f"Iteration {iteration:4d}") + fig.savefig(f"{output_dir}/pipeline_{iteration:05d}.png") + plt.close(fig) + + solver = PhillipsDessouky(dag) + + draw_xlim = None + iteration = 0 + for iteration, result in enumerate(solver.run()): + # Maybe draw the pipeline. + if iteration % args.plot_interval == 0: + if draw_xlim is None: + draw_xlim = int(result.real_time) + 1 + draw(dag, iteration, draw_xlim) + + # Write the frequency assignment Python file. + f = open(args.output_dir / f"freqs_pipeline_{iteration:05d}.py", "w") + f.write("[\n") + for stage_id, stage_insts in enumerate(instructions): + stage_freq: list[tuple[str, int]] = [] + for inst in stage_insts: + stage_freq.append((type(inst).__name__.lower(), inst.assigned_knob)) + f.write(f"{stage_freq},\n") + f.write("]\n") + + iter_str = f"# Iteration {iteration} " + real_cost = result.cost + args.num_stages + result.real_time * args.p2p_power + f.write(iter_str + f"cost change: {result.cost_change}\n") + f.write(iter_str + f"total cost: {result.cost}\n") + f.write(iter_str + f"total cost with P2P: {real_cost}\n") + + assert draw_xlim is not None + draw(dag, iteration, draw_xlim) + + +if __name__ == "__main__": + args = tyro.cli(Args) + + start_time = time.time() + main(args) + logger.info("Total time: %.2fs", time.time() - start_time) diff --git a/examples/perseus/wide-resnet.gif b/examples/perseus/wide-resnet.gif new file mode 100644 index 00000000..5a820682 Binary files /dev/null and b/examples/perseus/wide-resnet.gif differ diff --git a/mkdocs.yml b/mkdocs.yml index d06f10f2..20d60caa 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -111,6 +111,9 @@ nav: - getting_started/index.md - Environment Setup: getting_started/environment.md - Installing and Building: getting_started/installing_and_building.md + - Perseus: + - perseus/index.md + - Integrating: perseus/integrating.md - Extending Zeus: extend.md - Source Code Reference: reference/ diff --git a/pyproject.toml b/pyproject.toml index 89b3b083..2b675eeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,11 @@ dependencies = [ "nvidia-ml-py", "pydantic<2", "rich", + "tyro", + "fastapi[all]==0.87.0", + "httpx", + "aiofiles==22.1.0", + "lowtime", ] dynamic = ["version"] @@ -38,8 +43,7 @@ Documentation = "https://ml.energy/zeus" [project.optional-dependencies] lint = ["ruff", "black==22.6.0"] test = ["pytest==7.3.2", "pytest-mock==3.10.0", "pytest-xdist==3.3.1"] -torch = ["torch==2.0.1"] -dev = ["ruff", "black==22.6.0", "pytest==7.3.2", "pytest-mock==3.10.0", "pytest-xdist==3.3.1", "torch==2.0.1"] +dev = ["ruff", "black==22.6.0", "pytest==7.3.2", "pytest-mock==3.10.0", "pytest-xdist==3.3.1"] [tool.setuptools.packages.find] where = ["."] @@ -76,6 +80,8 @@ convention = "google" [tool.ruff.per-file-ignores] "**/__init__.py" = ["F401", "F403"] +"zeus/optimizer/perseus/common.py" = ["N805"] +"zeus/optimizer/perseus/server/router.py" = ["B008"] [tool.pytest.ini_options] addopts = "--numprocesses auto" diff --git a/zeus/__init__.py b/zeus/__init__.py index 7edee2bd..141f7088 100644 --- a/zeus/__init__.py +++ b/zeus/__init__.py @@ -26,4 +26,4 @@ - [`util`][zeus.util]: Utility functions and classes. """ -__version__ = "0.7.1" +__version__ = "0.8.0" diff --git a/zeus/analyze.py b/zeus/analyze.py index 6f64f32b..702389d4 100644 --- a/zeus/analyze.py +++ b/zeus/analyze.py @@ -111,7 +111,7 @@ def avg_power( seconds = _get_seconds(df) watts = _get_watts(df) area = auc(seconds, watts) - return area / (seconds.max() - seconds.min()) + return area / (max(seconds) - min(seconds)) def _get_seconds(df: pd.DataFrame) -> pd.Series: diff --git a/zeus/callback.py b/zeus/callback.py index 45904282..ace972d1 100644 --- a/zeus/callback.py +++ b/zeus/callback.py @@ -41,6 +41,12 @@ def on_step_end(self) -> None: def on_evaluate(self, metric: float) -> None: """Called after evaluating the model.""" + def on_instruction_begin(self, name: str) -> None: + """Called at the beginning of pipeline instructions like forward or backward.""" + + def on_instruction_end(self, name: str) -> None: + """Called at the end of pipeline instructions like forward or backward.""" + class CallbackSet(Callback): """A set of callbacks.""" @@ -83,3 +89,13 @@ def on_evaluate(self, metric: float) -> None: """Called after evaluating the model.""" for callback in self.callbacks: callback.on_evaluate(metric) + + def on_instruction_begin(self, name: str) -> None: + """Called at the beginning of pipeline instructions like forward or backward.""" + for callback in self.callbacks: + callback.on_instruction_begin(name) + + def on_instruction_end(self, name: str) -> None: + """Called at the end of pipeline instructions like forward or backward.""" + for callback in self.callbacks: + callback.on_instruction_end(name) diff --git a/zeus/monitor/energy.py b/zeus/monitor/energy.py index 09ce5a42..c0e13cf6 100644 --- a/zeus/monitor/energy.py +++ b/zeus/monitor/energy.py @@ -18,7 +18,6 @@ import os import atexit -import logging from time import time from pathlib import Path from dataclasses import dataclass @@ -31,6 +30,8 @@ from zeus.util.framework import cuda_sync from zeus.util.env import resolve_gpu_indices +logger = get_logger(__name__) + @dataclass class Measurement: @@ -141,20 +142,19 @@ def __init__( self.gpu_handles[gpu_index] = handle # Initialize loggers. - self.logger = get_logger(type(self).__name__) if log_file is None: self.log_file = None else: if dir := os.path.dirname(log_file): os.makedirs(dir, exist_ok=True) self.log_file = open(log_file, "w") - self.logger.info("Writing measurement logs to %s.", log_file) + logger.info("Writing measurement logs to %s.", log_file) self.log_file.write( f"start_time,window_name,elapsed_time,{','.join(map(lambda i: f'gpu{i}_energy', self.gpu_indices))}\n", ) self.log_file.flush() - self.logger.info("Monitoring GPU %s.", self.gpu_indices) + logger.info("Monitoring GPU indices %s.", self.gpu_indices) # A dictionary that maps the string keys of active measurement windows to # the state of the measurement window. Each element in the dictionary is a tuple of: @@ -230,7 +230,7 @@ def begin_window(self, key: str, sync_cuda: bool = True) -> None: # Add measurement state to dictionary. self.measurement_states[key] = (timestamp, energy_state) - self._log(f"Measurement window '{key}' started.") + logger.debug("Measurement window '%s' started.", key) def end_window( self, key: str, sync_cuda: bool = True, cancel: bool = False @@ -269,7 +269,7 @@ def end_window( # If the measurement window is cancelled, return an empty Measurement object. if cancel: - self._log(f"Measurement window '{key}' cancelled.") + logger.debug("Measurement window '%s' cancelled.", key) return Measurement(time=0.0, energy={gpu: 0.0 for gpu in self.gpu_handles}) end_time: float = time() @@ -300,7 +300,7 @@ def end_window( time_consumption - power_measurement_time ) - self._log(f"Measurement window '{key}' ended.") + logger.debug("Measurement window '%s' ended.", key) # Add to log file. if self.log_file is not None: @@ -312,18 +312,3 @@ def end_window( self.log_file.flush() return Measurement(time_consumption, energy_consumption) - - def _log( - self, message: str, gpu_index: int | None = None, level: int = logging.INFO - ) -> None: - """Print out message with prefix. - - Args: - message: The message to log out. - gpu_index: The index of GPU for GPU-level logging. Should be `None` - when logging global information. (Default: `None`) - level: The logging level to use. (Default: `logging.INFO`) - """ - if gpu_index is not None: - message = f"[GPU {gpu_index}] {message}" - self.logger.log(level, message) diff --git a/zeus/optimizer/perseus/__init__.py b/zeus/optimizer/perseus/__init__.py new file mode 100644 index 00000000..9bfcb671 --- /dev/null +++ b/zeus/optimizer/perseus/__init__.py @@ -0,0 +1,20 @@ +# Copyright (C) 2023 Jae-Won Chung +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Optimizer that schedules energy consumption with Perseus. + +Currently, this optimizer depends on PyTorch. +""" + +from zeus.optimizer.perseus.optimizer import PerseusOptimizer diff --git a/zeus/optimizer/perseus/common.py b/zeus/optimizer/perseus/common.py new file mode 100644 index 00000000..67d5eb90 --- /dev/null +++ b/zeus/optimizer/perseus/common.py @@ -0,0 +1,308 @@ +# Copyright (C) 2023 Jae-Won Chung +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared constants and models between the Perseus server and the client (optimizer).""" + +from __future__ import annotations + +import os +import inspect +from datetime import datetime +from typing import Any, Optional + +import aiofiles +import pandas as pd +from pydantic import BaseModel, BaseSettings, Field, validator, PyObject + +GET_SERVER_INFO_URL = "/info" +REGISTER_JOB_URL = "/register_job" +REGISTER_RANK_URL = "/register_rank/{job_id}" +GET_FREQUENCY_SCHEDULE_URL = "/schedule/{job_id}" +REPORT_PROFILING_RESULT_URL = "/result/{job_id}" + + +class PerseusSettings(BaseSettings): + """Perseus settings, configurable via environment variables. + + For instance, setting `PERSEUS_SCHEDULER=AllMaxFrequency` will automatically + import `zeus.optimizer.perseus.server.scheduler.AllMaxFrequency` and + the `scheduler` variable will hold it a reference to the class. + + Attributes: + scheduler: Name of the `FrequencyScheduler` to use. + scheduler_args: Any extra arguments required by `scheduler.__init__`. + log_level: Log level, e.g. "debug", "info". + dump_data: Whether the scheduler should dump internal state to the filesystem + (for future inspection purposes). + dump_dir: Directory to dump state in (if enabled) + max_job_idle_time: Maximum time in seconds that a job can be idle for before + its states are automatically deleted from the server. + """ + + scheduler: PyObject = "PointSolution" # type: ignore + scheduler_args: dict[str, Any] = {} + log_level: str = "DEBUG" + dump_data: bool = True + dump_dir: str = "./dump" + max_job_idle_time: int = 60 * 60 * 24 * 7 # 1 week + + @validator("scheduler", pre=True) + def _fix_scheduler_import_path(cls, value): + """Prepend `zeus.optimizer.perseus.server.scheduler.` to the scheduler type name.""" + return f"zeus.optimizer.perseus.server.scheduler.{value}" + + @validator("scheduler_args") + def _validate_scheduler_args(cls, args, values): + """Check whether args are as expected by the scheduler's constructor.""" + scheduler = values["scheduler"] + full_args = args | dict(job_info=None, rank_infos=None, perseus_settings=None) + constructor_args = inspect.signature(scheduler) + try: + constructor_args.bind(**full_args) + except TypeError as e: + raise ValueError(f"Invalid scheduler args: {e}") from None + return args + + @validator("log_level") + def _make_upper_case(cls, value): + return value.upper() + + class Config: + """Configuration class read by pydantic.""" + + env_prefix = "perseus_" + + +class JobInfo(BaseModel): + """Training job information reported to the server. + + Attributes: + job_id: Globally unique ID of the training job, generated by the server. + This field should be an empty string when sent to the server. + pp_degree: Pipeline parallel degree. + dp_degree: Data parallel degree. + tp_degree: Tensor parallel degree. + world_size: World size of the training job. + job_metadata: An optional arbitrary string that describes the job. This will + be appended to the job ID if given. Typically for logging purposes. + """ + + job_id: str = "" + pp_degree: int = Field(ge=1) + dp_degree: int = Field(ge=1) + tp_degree: int = Field(ge=1) + world_size: int = Field(ge=1) + job_metadata: Optional[str] = None + + @validator("job_id") + def _check_empty_job_id(cls, job_id): + assert not job_id + return job_id + + @validator("world_size") + def _check_world_size(cls, world_size, values): + """Product of PP, DP, and TP degree would be identical to the world size.""" + assert ( + values["pp_degree"] * values["dp_degree"] * values["tp_degree"] + == world_size + ) + return world_size + + def set_job_id(self, scheduler_name: str): + """Generate and set the job ID.""" + self.job_id = "+".join( + [ + datetime.now().strftime("%F-%H-%M-%S"), + f"dp{self.dp_degree}", + f"pp{self.pp_degree}", + f"tp{self.tp_degree}", + scheduler_name, + ] + ) + if self.job_metadata: + self.job_id += f"+{self.job_metadata}" + + +class RankInfo(BaseModel): + """Information passed to the server from each rank. + + Attributes: + rank: Global rank of the reporting process. + dp_rank: Data parallel rank of the reporting procees. + pp_rank: Pipeline parallel rank of the reporting procees. + tp_rank: Tensor parallel rank of the reporting procees. + available_frequencies: List of available frequencies for the rank's GPU. + """ + + rank: int = Field(ge=0) + dp_rank: int = Field(ge=0) + pp_rank: int = Field(ge=0) + tp_rank: int = Field(ge=0) + available_frequencies: list[int] + + +class FrequencySchedule(BaseModel): + """Frequency schedule for one iteration. + + `frequencies` is a list of tuples, where the first element is the name of the + instruction and the second element is the frequency to use for that instruction. + """ + + rank: int = Field(ge=0) + frequencies: list[tuple[str, int]] + + +class ProfilingResult(BaseModel): + """Profiling results for a `FrequencySchedule` of a rank. + + Attributes: + rank: Global rank of the reporting client. + iter_time: List of latency of all iterations within the profiling window in seconds. + iter_energy: List of energy consumption of all iterations within the profiling window in Joules. + time_breakdown: Duration of each operation across multiple iterations. + e.g. `time_breakdown["forward"][i]` is the list of latencies of all forward computations + in the `i`th iteration. + energy_breakdown: Energy consumption of each operation across multple iterations. + Value has the same structure as `time_breakdown`. + """ + + rank: int = Field(ge=0) + iter_time: list[float] + iter_energy: list[float] + time_breakdown: dict[str, list[list[float]]] = {} + energy_breakdown: dict[str, list[list[float]]] = {} + + +class OfflineProfilingResult(BaseModel): + """Profiling results generated from offline profiling each instruction. + + Attributes: + rank: Global rank of the reporting client. + dp_rank: Data parallel rank of the reporting procees. + pp_rank: Pipeline parallel rank of the reporting procees. + tp_rank: Tensor parallel rank of the reporting procees. + forward_time: Dict that maps frequency to average forward computation time. + forward_energy: Dict that maps frequency to average forward computation energy. + backward_time: Dict that maps frequency to average backward computation time. + backward_energy: Dict that maps frequency to average backward computation energy. + """ + + rank: int = Field(ge=0) + dp_rank: int = Field(ge=0) + pp_rank: int = Field(ge=0) + tp_rank: int = Field(ge=0) + forward_time: dict[int, float] + forward_energy: dict[int, float] + backward_time: dict[int, float] + backward_energy: dict[int, float] + + +class InstructionProfilingResult(BaseModel): + """Time and energy profiling results for each instruction in each stage.""" + + __root__: list[OfflineProfilingResult] + + def to_csv(self, filepath: str) -> None: + """Serialize and save this object into a CSV file. + + Columns: rank, dp_rank, pp_rank, tp_rank, stage, instruction, frequency, time, energy + Notes + - `rank` is the global rank of the process. + - `pp_rank` and `stage` are always the same, for backwards compatibility. + - All ranks and `stage` are zero-indexed. + - `instruction` is either "forward" or "backward". + - `time` and `energy` are already averaged over profiling iterations. + """ + if not filepath.endswith(".csv"): + raise ValueError("Filepath does not end with '.csv'") + + # fmt: off + headers = ["rank", "dp_rank", "pp_rank", "tp_rank", "stage", "instruction", "frequency", "time", "energy"] + records: list[tuple[int, int, int, int, int, str, int, float, float]] = [] + for res in self.__root__: + prefix = (res.rank, res.dp_rank, res.pp_rank, res.tp_rank, res.pp_rank) + for freq in res.forward_time: + records.append((*prefix, "forward", freq, res.forward_time[freq], res.forward_energy[freq])) + for freq in res.backward_time: + records.append((*prefix, "backward", freq, res.backward_time[freq], res.backward_energy[freq])) + # fmt: on + + df = pd.DataFrame.from_records(records, columns=headers) + df.to_csv(filepath, index=False) + + +async def save_prof( + data: list[ProfilingResult], + directory: str, + schedule_num: int, +) -> None: + """Save a list of `ProfilingResult`s in the designated directory.""" + os.makedirs(directory, exist_ok=True) + async with aiofiles.open(f"{directory}/{schedule_num}.prof.json", "w") as f: + obj = _ProfilingResultList(__root__=data).json() + await f.write(obj) + + +def load_prof(directory: str, schedule_num: int) -> list[ProfilingResult]: + """Load a list of `ProfilingResult`s saved in the designated directory.""" + filepath = f"{directory}/{schedule_num}.prof.json" + return _ProfilingResultList.parse_file(filepath).__root__ + + +async def save_sched( + data: list[FrequencySchedule], + directory: str, + schedule_num: int, +) -> None: + """Save a list of `FrequencySchedule`s in the designated directory.""" + os.makedirs(directory, exist_ok=True) + async with aiofiles.open(f"{directory}/{schedule_num}.sched.json", "w") as f: + obj = _FrequencyScheduleList(__root__=data).json() + await f.write(obj) + + +def load_sched(directory: str, schedule_num: int) -> list[FrequencySchedule]: + """Load a list of `FrequencySchedule`s saved in the designated directory.""" + filepath = f"{directory}/{schedule_num}.sched.json" + return _FrequencyScheduleList.parse_file(filepath).__root__ + + +async def save_ranks(data: list[RankInfo], directory: str) -> None: + """Save a list of `RankInfo`s in the designated directory.""" + os.makedirs(directory, exist_ok=True) + async with aiofiles.open(f"{directory}/ranks.json", "w") as f: + obj = _RankInfoList(__root__=data).json() + await f.write(obj) + + +def load_ranks(directory: str) -> list[RankInfo]: + """Load a list of `RankInfo`s saved in the designated directory.""" + filepath = f"{directory}/ranks.json" + return _RankInfoList.parse_file(filepath).__root__ + + +# Proxy classes for a list of Pydantic objects. +# __root__ is making use of Pydantic's Custom Root Type for a cleaner JSON representation. + + +class _ProfilingResultList(BaseModel): + __root__: list[ProfilingResult] + + +class _FrequencyScheduleList(BaseModel): + __root__: list[FrequencySchedule] + + +class _RankInfoList(BaseModel): + __root__: list[RankInfo] diff --git a/zeus/optimizer/perseus/frequency_controller.py b/zeus/optimizer/perseus/frequency_controller.py new file mode 100644 index 00000000..4deefe27 --- /dev/null +++ b/zeus/optimizer/perseus/frequency_controller.py @@ -0,0 +1,89 @@ +# Copyright (C) 2023 Jae-Won Chung +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Controller that sets the GPU's frequency in a non-blocking fashion.""" + +from __future__ import annotations + +import atexit +import contextlib +import multiprocessing as mp + +import pynvml + + +class FrequencyController: + """Spawns a separate process that sets the GPU frequency.""" + + def __init__(self, nvml_device_id: int = 0) -> None: + """Instantiate the frequency controller. + + Args: + nvml_device_id: The NVML device ID of the GPU to control. + """ + self._q: mp.Queue[int | None] = mp.Queue() + self._proc = mp.Process(target=self._controller_process, args=(nvml_device_id,)) + + atexit.register(self.end) + self._proc.start() + + def set_frequency(self, frequency: int) -> None: + """Set the GPU's frequency asynchronously. + + If `frequency` is zero, returns without doing anything. + """ + if frequency != 0: + self._q.put(frequency, block=False) + + def end(self) -> None: + """Stop the controller process.""" + self._q.put(None, block=False) + + def _controller_process(self, device_id: int) -> None: + """Receive frequency values through a queue and apply it.""" + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) + + # Return the power limit to the default. + pynvml.nvmlDeviceSetPowerManagementLimit( + handle, + pynvml.nvmlDeviceGetPowerManagementDefaultLimit(handle), + ) + + # Set the memory frequency to be the highest. + max_mem_freq = max(pynvml.nvmlDeviceGetSupportedMemoryClocks(handle)) + with contextlib.suppress(pynvml.NVMLError_NotSupported): # type: ignore + pynvml.nvmlDeviceSetMemoryLockedClocks(handle, max_mem_freq, max_mem_freq) + + # Set the SM frequency to be the highest. + max_freq = max( + pynvml.nvmlDeviceGetSupportedGraphicsClocks(handle, max_mem_freq) + ) + pynvml.nvmlDeviceSetGpuLockedClocks(handle, max_freq, max_freq) + current_freq = max_freq + + # Wait on the queue for the next frequency to set. + while True: + target_freq = self._q.get(block=True) + if target_freq is None: + break + if current_freq != target_freq: + pynvml.nvmlDeviceSetGpuLockedClocks(handle, target_freq, target_freq) + current_freq = target_freq + + # Reset everything. + with contextlib.suppress(pynvml.NVMLError_NotSupported): # type: ignore + pynvml.nvmlDeviceResetMemoryLockedClocks(handle) + pynvml.nvmlDeviceResetGpuLockedClocks(handle) + pynvml.nvmlShutdown() diff --git a/zeus/optimizer/perseus/optimizer.py b/zeus/optimizer/perseus/optimizer.py new file mode 100644 index 00000000..ce4d87c1 --- /dev/null +++ b/zeus/optimizer/perseus/optimizer.py @@ -0,0 +1,236 @@ +# Copyright (C) 2023 Jae-Won Chung +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Perseus optimizer implementation. + +The `PerseusOptimizer` is to be integrated into the user-side framework. +It is responsible for communicating with the Perseus server and managing +the `FrequencyController` instance, which is responsible for controlling +the frequency of the CPU of the current process. +""" + +from __future__ import annotations + +import httpx +import pynvml +import torch +import torch.distributed as dist + +from zeus.callback import Callback +from zeus.optimizer.perseus.frequency_controller import FrequencyController +from zeus.optimizer.perseus.common import ( + GET_FREQUENCY_SCHEDULE_URL, + REGISTER_JOB_URL, + REGISTER_RANK_URL, + JobInfo, + RankInfo, + FrequencySchedule, +) +from zeus.util.env import resolve_gpu_indices +from zeus.util.framework import cuda_sync + + +class PerseusOptimizer(Callback): + """Perseus optimizer.""" + + def __init__( + self, + rank: int, + dp_rank: int, + pp_rank: int, + tp_rank: int, + device_id: int, + dp_degree: int, + pp_degree: int, + tp_degree: int, + world_size: int, + server_url: str, + job_metadata: str | None = None, + ) -> None: + """Initialize the Perseus optimizer. + + Assumptions: + - `torch.distributed` has been initialized. + - `torch.cuda.set_device` has been called with `device_id`. + This is needed to broadcast the job ID to all ranks. + + The master process (rank 0) will register the job with the Peresus + server and retrieve the job ID of this job. Then, each rank will + report itself to the Perseus server with the job ID. + + Args: + rank: Global rank of the current process. + dp_rank: Rank in the data parallel group. + pp_rank: Rank in the pipeline parallel group. + tp_rank: Rank in the tensor parallel group. + device_id: CUDA device ID that the current process manages. + dp_degree: Size of the data parallel group. + pp_degree: Size of the pipeline parallel group. + tp_degree: Size of the tensor parallel group. + world_size: Total number of ranks that participate in training. + server_url: URL of the Perseus server. + job_metadata: An optional arbitrary string that describes the job. This will + be appended to the job ID if given. Typically for logging purposes. + """ + if not dist.is_initialized(): + raise RuntimeError( + "Instantiate `PerseusOptimizer` after `init_process_group`." + ) + + self.server_url = server_url + self.rank = rank + self.dp_rank = dp_rank + self.pp_rank = pp_rank + self.tp_rank = tp_rank + + cuda_device_ids, nvml_device_ids = resolve_gpu_indices([device_id]) + self.cuda_device_id = cuda_device_ids[0] + nvml_device_id = nvml_device_ids[0] + # It is assumed that `torch.cuda.set_device` has been called with `device_id`. + # It won't hurt to call this again. + torch.cuda.set_device(self.cuda_device_id) + + # Rank 0 registers the job with the Perseus server and retrieves the job ID. + job_id = None + if rank == 0: + job_info = JobInfo( + pp_degree=pp_degree, + dp_degree=dp_degree, + tp_degree=tp_degree, + world_size=world_size, + job_metadata=job_metadata, + ) + response = httpx.post( + self.server_url + REGISTER_JOB_URL, json=job_info.dict() + ) + if (code := response.status_code) != 200: + raise RuntimeError( + f"Perseus server returned status code {code}: {response.text}" + ) + job_id = response.json() + if not isinstance(job_id, str): + raise RuntimeError( + f"Perseus server returned a strange job ID: {job_id=}" + ) + + # Rank 0 broadcasts the job ID across all ranks. + objects = [job_id] + dist.broadcast_object_list(objects, src=0) + self.job_id = objects[0] + if self.job_id is None: + raise RuntimeError("Failed to broadcast job ID to all ranks") + + # Query the list of available frequencies of the GPU. + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(nvml_device_id) + max_mem_freq = max(pynvml.nvmlDeviceGetSupportedMemoryClocks(handle)) + freqs = sorted( + pynvml.nvmlDeviceGetSupportedGraphicsClocks(handle, max_mem_freq), + reverse=True, + ) + pynvml.nvmlShutdown() + + # Each rank reports itself to the Perseus server with the job ID. + rank_info = RankInfo( + rank=self.rank, + dp_rank=self.dp_rank, + pp_rank=self.pp_rank, + tp_rank=self.tp_rank, + available_frequencies=freqs, + ) + response = httpx.post( + self.server_url + REGISTER_RANK_URL.format(job_id=self.job_id), + json=rank_info.dict(), + ) + if (code := response.status_code) != 200: + raise RuntimeError( + f"Perseus server returned status code {code}: {response.text}" + ) + + # The frequency controller is responsible for controlling the frequency + # of the GPU (nvml_device_id) asynchronously. + self.frequency_controller = FrequencyController(nvml_device_id=nvml_device_id) + + # Fetch the frequency schedule from the Perseus server. + self.freq_schedule = self._get_frequency_schedule() + self.freq_schedule_iter = iter(self.freq_schedule) + + def _get_frequency_schedule(self) -> list[tuple[str, int]]: + """Get the frequency schedule from the Perseus server.""" + response = httpx.get( + self.server_url + GET_FREQUENCY_SCHEDULE_URL.format(job_id=self.job_id), + params={"rank": self.rank}, + timeout=None, + ) + if (code := response.status_code) != 200: + raise RuntimeError( + f"Perseus server returned status code {code}: {response.text}" + ) + schedule = FrequencySchedule.parse_raw(response.text) + if schedule.rank != self.rank: + raise RuntimeError( + f"Perseus server returned a schedule for rank {schedule.rank} to rank {self.rank}" + ) + return schedule.frequencies + + def on_step_begin(self) -> None: + """Mark the beginning of a step. + + TODO(jaywonchung): InstructionProfiler iteration start mark. + """ + pass + + def on_step_end(self) -> None: + """Mark the end of a step. + + TODO(jaywonchung): InstructionProfiler iteration end mark. + Also report the profiling result to the Perseus server after N iterations. + """ + # Frequency schedule holds one iteration-worth of frequencies, so at + # the end of each iteration, the iterator should be exhausted. + item = next(self.freq_schedule_iter, None) + if item is not None: + raise RuntimeError( + "Perseus server returned more frequencies than expected. " + f"Next expected instruction and frequency is {item}" + ) + self.freq_schedule_iter = iter(self.freq_schedule) + + def on_instruction_begin(self, name: str) -> None: + """Mark the beginning of an instruction, like forward and backward. + + Retrieve the next frequency from the schedule, check whether the next + expected instruction matches the name of the instruction, and set the + frequency accordingly. + """ + cuda_sync(self.cuda_device_id) + + # Retrieve the next frequency from the schedule. + item = next(self.freq_schedule_iter, None) + if item is None: + raise RuntimeError( + "Perseus server returned fewer frequencies than expected" + ) + + # Check whether the next expected instruction matches the name of the instruction. + instruction, frequency = item + if instruction != name: + raise RuntimeError( + f"The next expected instruction is not forward: {instruction}" + ) + + self.frequency_controller.set_frequency(frequency) + + def on_instruction_end(self, _: str) -> None: + """Mark the end of an instruction, like forward and backward.""" diff --git a/zeus/optimizer/perseus/server/__init__.py b/zeus/optimizer/perseus/server/__init__.py new file mode 100644 index 00000000..3331ddaa --- /dev/null +++ b/zeus/optimizer/perseus/server/__init__.py @@ -0,0 +1,22 @@ +# Copyright (C) 2023 Jae-Won Chung +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The Perseus server guides the PerseusOptimizer with frequency plans. + +The server is agnostic to the training framework the PerseusOptimizer +is integrated with. A server is useful because large model training is +typically distributed, and we still need one place to coordinate the +frequency plans. Later, the server will be extended to support complete +online profiling and optimization. +""" diff --git a/zeus/optimizer/perseus/server/job_manager.py b/zeus/optimizer/perseus/server/job_manager.py new file mode 100644 index 00000000..8cdf6656 --- /dev/null +++ b/zeus/optimizer/perseus/server/job_manager.py @@ -0,0 +1,240 @@ +# Copyright (C) 2023 Jae-Won Chung +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The JobManager singleton class manages all job states.""" + +from __future__ import annotations + +import time +import asyncio +import traceback + +from fastapi import HTTPException + +from zeus.optimizer.perseus.common import ( + JobInfo, + PerseusSettings, + FrequencySchedule, + ProfilingResult, + RankInfo, + save_prof, + save_sched, + save_ranks, +) +from zeus.util.logging import get_logger +from zeus.util.async_utils import create_task + +GLOBAL_JOB_MANAGER: JobManager | None = None + +logger = get_logger(__name__) + + +class JobManager: + """A singleton class that manages all states.""" + + def __init__(self, perseus_settings: PerseusSettings) -> None: + """Initialize the job manager.""" + self.perseus_settings = perseus_settings + + self._job_infos: dict[str, JobInfo] = {} + self._job_rank_infos: dict[str, list[RankInfo]] = {} + self._job_tasks: dict[str, asyncio.Task] = {} + self._job_result_channels: dict[str, asyncio.Queue[ProfilingResult]] = {} + self._job_sched_request_channels: dict[str, asyncio.Queue] = {} + self._job_sched_response_channels: dict[str, list[asyncio.Queue]] = {} + self._job_last_active_time: dict[str, float] = {} + + # Spawn cleanup task that evicts the state of jobs that have not been active + # for a long time. + create_task( + self._cleanup_task( + cleanup_period=60, + max_idle_time=perseus_settings.max_job_idle_time, + ), + logger=logger, + ) + + def register_job(self, job_info: JobInfo) -> None: + """Prepare internal state for a new job. + + This method will be invoked exactly once by the global rank 0 (master) process. + """ + job_id = job_info.job_id + world_size = job_info.world_size + self._job_infos[job_id] = job_info + self._job_rank_infos[job_id] = [] + self._job_result_channels[job_id] = asyncio.Queue(maxsize=world_size) + self._job_sched_request_channels[job_id] = asyncio.Queue(maxsize=world_size) + self._job_sched_response_channels[job_id] = [ + asyncio.Queue(maxsize=1) for _ in range(world_size) + ] + self._job_tasks[job_id] = create_task( + self._job_task(job_id, self.perseus_settings.dump_data), + logger=logger, + ) + self._job_last_active_time[job_id] = time.monotonic() + + def register_rank(self, job_id: str, rank_info: RankInfo) -> None: + """Register rank-specific information for an already registered job. + + This method will be invoked `world_size` number of times (once per rank). + """ + self._job_rank_infos[job_id].append(rank_info) + self._job_last_active_time[job_id] = time.monotonic() + + async def get_frequency_schedule(self, job_id: str, rank: int) -> FrequencySchedule: + """Get the next frequency schedule for a rank. + + This method will be called `world_size` number of times (once per rank). + All ranks will block on this method untill everyone reports their + profiling results and calls this method. + + When an internal scheduler error happened at any point of servicing the + job, clients will be notified through this API with a 500 Internal Error. + """ + await self._job_sched_request_channels[job_id].put(rank) + res = await self._job_sched_response_channels[job_id][rank].get() + if isinstance(res, Exception): + code = 400 if isinstance(res, ValueError) else 500 + raise HTTPException( + status_code=code, + detail="".join( + traceback.format_exception(type(res), res, res.__traceback__) + ), + ) + self._job_last_active_time[job_id] = time.monotonic() + return res + + def report_profiling_result(self, job_id: str, result: ProfilingResult) -> None: + """Send the profiling result to the job task and immediately return. + + This method will be called `world_size` number of times - one for each rank. + """ + self._job_result_channels[job_id].put_nowait(result) + self._job_last_active_time[job_id] = time.monotonic() + + async def _cleanup_task( + self, + cleanup_period: int, + max_idle_time: int, + ) -> None: + """Periodically evict job states. + + Args: + cleanup_period: How often to run the cleanup task, in seconds. + max_idle_time: Maximum amount of time a job can be idle for, in seconds. + """ + while True: + await asyncio.sleep(cleanup_period) + for job_id in list(self._job_last_active_time.keys()): + if ( + time.monotonic() - self._job_last_active_time[job_id] + > max_idle_time + ): + self._job_tasks[job_id].cancel() + del self._job_infos[job_id] + del self._job_rank_infos[job_id] + del self._job_result_channels[job_id] + del self._job_sched_request_channels[job_id] + del self._job_sched_response_channels[job_id] + del self._job_tasks[job_id] + del self._job_last_active_time[job_id] + + async def _job_task(self, job_id: str, dump_data: bool) -> None: + """Coalese requests and responses of each rank and interface with the scheduler.""" + result_chan = self._job_result_channels[job_id] + sched_req_chan = self._job_sched_request_channels[job_id] + sched_resp_chan = self._job_sched_response_channels[job_id] + + job_info = self._job_infos[job_id] + + try: + # Wait until all ranks have reported their `RankInfo`s. + rank_infos = self._job_rank_infos[job_id] + while True: + await asyncio.sleep(0.1) + # Indexing the first element is always safe because this task is + # created after putting the `RankInfo` of the first-connected rank + # in `self.job_rank_infos[job_id]`. + if len(rank_infos) == job_info.world_size: + break + + # Sort `RankInfo`s in rank order. + rank_infos.sort(key=lambda r: r.rank) + + # Create directory to dump Perseus states. + dump_dir = f"{self.perseus_settings.dump_dir}/{job_id}" + if dump_data: + await save_ranks(rank_infos, dump_dir) + + # Instantiate the frequency scheduler. + scheduler = self.perseus_settings.scheduler( + job_info, + rank_infos, + self.perseus_settings, + **self.perseus_settings.scheduler_args, + ) + + # Provide next schedules, observe profiling results, and repeat. + schedule_num = 0 + while True: + # Compute the next `FrequencySchedule`s. + schedules = scheduler.next_schedule() + + # Wait until all the ranks ask for the next schedule. + await asyncio.gather(*[sched_req_chan.get() for _ in rank_infos]) + + # Send out `FrequencySchedule`s. + await asyncio.gather( + *[sched_resp_chan[s.rank].put(s) for s in schedules] + ) + + # Gather profiling results from all ranks. + results = await asyncio.gather(*[result_chan.get() for _ in rank_infos]) + results.sort(key=lambda r: r.rank) + + # Dump profiling results and schedules. + if dump_data: + schedules.sort(key=lambda s: s.rank) + await save_prof(results, dump_dir, schedule_num) + await save_sched(schedules, dump_dir, schedule_num) + + # Send `ProfilingResult`s to the scheduler. + scheduler.observe(results) + + # Increment schedule number. + schedule_num += 1 + + except asyncio.CancelledError: + # This task gets cancelled when it's idle for too long and evicted. + pass + + except Exception as exc: + # In case the scheduler errored, send out the exception to the clients. + # The clients will receive the error when they ask for the next schedule. + for chan in sched_resp_chan: + chan.put_nowait(exc) + raise + + +def init_global_job_manager(perseus_settings: PerseusSettings) -> None: + """Instantiate the global singleton `JobManager`.""" + global GLOBAL_JOB_MANAGER + GLOBAL_JOB_MANAGER = JobManager(perseus_settings=perseus_settings) + + +def get_global_job_manager() -> JobManager: + """Fetch the global singleton `JobManager`.""" + assert GLOBAL_JOB_MANAGER is not None, "`init_global_job_manager` was not called." + return GLOBAL_JOB_MANAGER diff --git a/zeus/optimizer/perseus/server/router.py b/zeus/optimizer/perseus/server/router.py new file mode 100644 index 00000000..295f00fe --- /dev/null +++ b/zeus/optimizer/perseus/server/router.py @@ -0,0 +1,118 @@ +# Copyright (C) 2023 Jae-Won Chung +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Perseus server FastAPI router.""" + +from __future__ import annotations + +import logging +from typing import Callable + +from fastapi import Depends, FastAPI, Response, Request +from fastapi.routing import APIRoute + +from zeus.util.logging import get_logger +from zeus.optimizer.perseus.common import ( + REGISTER_JOB_URL, + REGISTER_RANK_URL, + GET_FREQUENCY_SCHEDULE_URL, + REPORT_PROFILING_RESULT_URL, + JobInfo, + RankInfo, + PerseusSettings, + ProfilingResult, + FrequencySchedule, +) +from zeus.optimizer.perseus.server.job_manager import ( + JobManager, + init_global_job_manager, + get_global_job_manager, +) + +logger = get_logger(__name__) +app = FastAPI() + + +class LoggingRoute(APIRoute): + """Route handler that logs out all requests and responses in DEBUG level.""" + + def get_route_handler(self) -> Callable: + """Wrap the original handler with debug messages.""" + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + response: Response = await original_route_handler(request) + logger.debug( + "%s %s: %s -> %s", + request.method, + request.url, + await request.json() if await request.body() else "None", + response.body.decode(response.charset), + ) + return response + + return custom_route_handler + + +settings = PerseusSettings() +logging.basicConfig(level=logging.getLevelName(settings.log_level)) +if logging.getLevelName(settings.log_level) <= logging.DEBUG: + app.router.route_class = LoggingRoute + + +@app.on_event("startup") +async def startup_hook(): + """Startup hook.""" + logger.info("Using scheduler `%s`", settings.scheduler.__name__) + init_global_job_manager(settings) + + +@app.post(REGISTER_JOB_URL, response_model=str) +async def register_job( + job_info: JobInfo, job_manager: JobManager = Depends(get_global_job_manager) +) -> str: + """Register the training job's information in the server.""" + job_info.set_job_id(scheduler_name=settings.scheduler.__name__) + job_manager.register_job(job_info) + return job_info.job_id + + +@app.post(REGISTER_RANK_URL) +async def register_rank( + job_id: str, + rank_info: RankInfo, + job_manager: JobManager = Depends(get_global_job_manager), +) -> None: + """Register each rank's information in the server.""" + job_manager.register_rank(job_id, rank_info) + + +@app.get(GET_FREQUENCY_SCHEDULE_URL, response_model=FrequencySchedule) +async def get_frequency_schedule( + job_id: str, + rank: int, + job_manager: JobManager = Depends(get_global_job_manager), +) -> FrequencySchedule: + """Return the next frequency schedule for the rank.""" + return await job_manager.get_frequency_schedule(job_id, rank) + + +@app.post(REPORT_PROFILING_RESULT_URL) +async def report_profiling_result( + job_id: str, + profiling_result: ProfilingResult, + job_manager: JobManager = Depends(get_global_job_manager), +) -> None: + """Report the profiling result for the most recent frequency schedule.""" + job_manager.report_profiling_result(job_id, profiling_result) diff --git a/zeus/optimizer/perseus/server/scheduler.py b/zeus/optimizer/perseus/server/scheduler.py new file mode 100644 index 00000000..86841abf --- /dev/null +++ b/zeus/optimizer/perseus/server/scheduler.py @@ -0,0 +1,292 @@ +# Copyright (C) 2023 Jae-Won Chung +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Interfaces for defining frequency schedulers.""" + +from __future__ import annotations + +import copy +from pathlib import Path +from contextlib import suppress +from abc import ABC, abstractmethod +from typing import Callable, Generator, Sequence, Type + +from zeus.optimizer.perseus.common import ( + PerseusSettings, + JobInfo, + RankInfo, + FrequencySchedule, + ProfilingResult, +) +from zeus.util.logging import get_logger + +logger = get_logger(__name__) + + +class FrequencyScheduler(ABC): + """Interface for classes that enclose frequency scheduling policies.""" + + def __init__( + self, + job_info: JobInfo, + rank_infos: list[RankInfo], + perseus_settings: PerseusSettings, + ) -> None: + """Initialize the scheduler. + + Args: + job_info: Info about the training job. + rank_infos: Info about all ranks. May not be sorted in rank order. + perseus_settings: PerseusSettings object. + """ + self.job_info = job_info + self.rank_infos = sorted(rank_infos, key=lambda info: info.rank) + self.world_size = self.job_info.world_size + self.perseus_settings = perseus_settings + + self._generator = self._run() + self._next_schedule: list[FrequencySchedule] | None = None + + def observe(self, profiling_results: list[ProfilingResult]) -> None: + """Ingest the profiling results for the previous schedule. + + Args: + profiling_results: Doesn't have to be sorted in rank order. + """ + # When there are no more schedules left to yield, the generator will + # raise `StopIteration`. We just ignore this, and later invocations of + # `next_schedule()` will return the last schedule returned forever. + with suppress(StopIteration): + self._next_schedule = self._generator.send(profiling_results) + + def next_schedule(self) -> list[FrequencySchedule]: + """Return the schedules for the next round of iterations. + + Returns: + A list of `FrequencySchedule`s. May not be sorted in rank order. + """ + if self._next_schedule is None: + try: + self._next_schedule = next(self._generator) + except StopIteration as exc: + raise RuntimeError( + "The _run generator raised StopIteration on its first next call.", + ) from exc + return self._next_schedule + + @abstractmethod + def _run(self) -> Generator[list[FrequencySchedule], list[ProfilingResult], None]: + """Yield next schedules and receives profiling results in one place. + + This is an alternative way to write a frequency scheduler. The advantage is + that everything is enclosed inside this method. The downside is that you'll + have to read this and understand how this generator works. + + The following implementation is a simple example of writing a scheduler using + this class. `yield` the next frequency schedule, and receive the profiling + results corresponding to that schedule from the `yield`. `observe` and + `next_schedule` will run the generator for you. + + In general, this generator should be designed to `yield` schedules infinitely. + However, if this was written to write a finite number of next schedules and + raise `StopIteration`, the last schedule cached inside `self._next_schedule` + will infinitely be returned from the call to `next_schedule`. This can be + useful when you converge to the optimal schedule and stop the generator, and + the rest of training will run with the final optimal schedule indefinitely. + """ + # This is an example implementation. + while True: + # Generate the next frequency schedule + next_schedule: list[FrequencySchedule] = [] + # Send the next schedule to client and receive the profiling result from client + profiling_results = yield next_schedule + # Ingest the profiling result + logger.debug("%s", profiling_results) + + +def make_3d_parallel( + sched_cls: Type[FrequencyScheduler], name: str | None = None +) -> Type[FrequencyScheduler]: + """Wrap `sched_cls` so that it is aware of 3D parallelism. + + Internally, this function subclasses `sched_cls` and overrides `observe` and + `next_schedule`. `observe` will aggregate the profiling results from all ranks + that share the same pp_rank and feed it to `super().observe`, while `next_schedule` + will first retrieve the per-stage schedule from `super().next_schedule` and then + copy-paste it to all ranks that share the same pp_rank. With this, the wrapped + scheduler can operate under the illusion that it's only deadling with pure pipeline + parallelism. + + Args: + sched_cls: The scheduler class to wrap. + name: Name of the scheduler. If None, use `sched_cls.__name__ + "3D"`. + """ + + class Wrapper(sched_cls): # type: ignore[valid-type,misc] + def __init__( + self, + job_info: JobInfo, + rank_infos: list[RankInfo], + perseus_settings: PerseusSettings, + *args, + **kwargs, + ) -> None: + self._orig_job_info = job_info + self._orig_rank_infos = rank_infos + + # Give the wrapped scheduler a perfect illusion of pure pipeline parallelism + # and no data or tensor parallelism. New rank is given by pp_rank. + job_info = copy.deepcopy(job_info) + job_info.dp_degree = 1 + job_info.tp_degree = 1 + job_info.world_size = job_info.pp_degree + + new_rank_infos = [] + for rank_info in rank_infos: + if rank_info.dp_rank == 0 and rank_info.tp_rank == 0: + new_rank_info = copy.deepcopy(rank_info) + new_rank_info.rank = rank_info.pp_rank + new_rank_infos.append(new_rank_info) + + super().__init__(job_info, rank_infos, perseus_settings, *args, **kwargs) + + def observe(self, profiling_results: list[ProfilingResult]) -> None: + """Aggregate results so that each pipeline stage has one result.""" + # Aggregate results from ranks that share the same pp_rank. + rank_to_pp_rank = { + rank_info.rank: rank_info.pp_rank for rank_info in self._orig_rank_infos + } + pp_results: list[list[ProfilingResult]] = [ + [] for _ in range(self._orig_job_info.pp_degree) + ] + for result in profiling_results: + pp_results[rank_to_pp_rank[result.rank]].append(result) + + # For each stage, construct a new ProfilingResult that aggregates all ranks. + # For iter_time and values in time_breakdown, take the max. + # For iter_energy and values in energy_breakdown, take the sum. + def agg_list(values: Sequence[list[float]], fun: Callable) -> list[float]: + return [fun(vals) for vals in zip(*values)] + + def agg_list_of_list( + values: Sequence[list[list[float]]], fun: Callable + ) -> list[list[float]]: + return [agg_list(vals, fun) for vals in zip(*values)] + + agg_results = [] + for pp_rank, results in enumerate(pp_results): + agg_result = ProfilingResult( + rank=pp_rank, + iter_time=agg_list([result.iter_time for result in results], max), + iter_energy=agg_list( + [result.iter_energy for result in results], sum + ), + time_breakdown={ + key: agg_list_of_list( + [result.time_breakdown[key] for result in results], max + ) + for key in results[0].time_breakdown + }, + energy_breakdown={ + key: agg_list_of_list( + [result.energy_breakdown[key] for result in results], sum + ) + for key in results[0].energy_breakdown + }, + ) + agg_results.append(agg_result) + logger.debug( + "Aggregated rank %s results for pp_rank %d: %s", + ", ".join([str(r.rank) for r in results]), + pp_rank, + agg_result, + ) + + # Finally, let the wrapped scheduler observe the aggregated results. + super().observe(agg_results) + + def next_schedule(self) -> list[FrequencySchedule]: + """Copy and paste the schedule for each stage to all ranks in that stage.""" + # Retrive the next schedule for each stage. + schedules = super().next_schedule() + + # Copy and paste the schedule for each stage to all ranks in that stage. + rank_to_pp_rank = { + rank_info.rank: rank_info.pp_rank for rank_info in self._orig_rank_infos + } + next_schedule = [] + for rank in range(self._orig_job_info.world_size): + pp_rank = rank_to_pp_rank[rank] + sched = copy.deepcopy(schedules[pp_rank]) + sched.rank = rank + next_schedule.append(sched) + logger.debug( + "Copied schedule for pp_rank %d to rank %d: %s", + pp_rank, + rank, + sched, + ) + return next_schedule + + Wrapper.__name__ = name or (sched_cls.__name__ + "3D") + if sched_cls.__doc__ is not None: + Wrapper.__doc__ = "[Wrapped for 3D parallelism] " + sched_cls.__doc__ + + return Wrapper + + +class PointSolution(FrequencyScheduler): + """Runs the given frequency schedule.""" + + def __init__( + self, + job_info: JobInfo, + rank_infos: list[RankInfo], + perseus_settings: PerseusSettings, + solution_path: str, + ) -> None: + """Initialize the scheduler. + + Args: + job_info: Info about the training job. + rank_infos: Info about all ranks. May not be sorted in rank order. + perseus_settings: PerseusSettings object. + solution_path: Path to the frequency Python file generated by lowtime. + """ + super().__init__(job_info, rank_infos, perseus_settings) + + self.solution_path = Path(solution_path) + if not self.solution_path.is_file(): + raise RuntimeError(f"Solution file not found: {solution_path}") + if self.solution_path.suffix != ".py": + raise RuntimeError(f"Solution file is not a Python file: {solution_path}") + + with open(self.solution_path, encoding="utf-8") as f: + schedule: list[list[tuple[str, int]]] = eval(f.read()) + if len(schedule) != self.world_size: + raise RuntimeError( + f"Solution file assumes {len(schedule)} ranks, but " + f"the job has {self.world_size} ranks." + ) + + self.schedule = [] + for rank, freqs in enumerate(schedule): + self.schedule.append(FrequencySchedule(rank=rank, frequencies=freqs)) + + def _run(self) -> Generator[list[FrequencySchedule], list[ProfilingResult], None]: + """Yield the schedule given by the solution path.""" + yield self.schedule + + +PointSolution3D = make_3d_parallel(PointSolution) diff --git a/zeus/simulate.py b/zeus/simulate.py index 62c6b3bf..b5b02878 100644 --- a/zeus/simulate.py +++ b/zeus/simulate.py @@ -321,7 +321,6 @@ class RunningJob: # We need a while loop here because we might have submitted a retry job # while reaping jobs that failed to reach the target metric, and that retry job # may finish before the current job. - # pylint: disable=cell-var-from-loop while any(map(lambda j: j.end_time <= current_time, running_jobs)): if self.verbose: print(f"[Simulator] Running jobs: {running_jobs}") diff --git a/zeus/util/async_utils.py b/zeus/util/async_utils.py new file mode 100644 index 00000000..e75c41cf --- /dev/null +++ b/zeus/util/async_utils.py @@ -0,0 +1,60 @@ +# Copyright (C) 2023 Jae-Won Chung +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for asyncio.""" + +from __future__ import annotations + +import asyncio +import logging +import functools +from typing import Any, Coroutine, TypeVar + +from zeus.util.logging import get_logger + +T = TypeVar("T") +default_logger = get_logger(__name__) + + +def create_task( + coroutine: Coroutine[Any, Any, T], + logger: logging.Logger | None = None, +) -> asyncio.Task[T]: + """Create an `asyncio.Task` but ensure that exceptions are logged. + + Reference: https://quantlane.com/blog/ensure-asyncio-task-exceptions-get-logged/ + + Args: + coroutine: The coroutine to be wrapped. + logger: The logger to be used for logging exceptions. If `None`, the + the logger with the name `zeus.util.async_utils` is used. + """ + loop = asyncio.get_running_loop() + task = loop.create_task(coroutine) + task.add_done_callback( + functools.partial(_handle_task_exception, logger=logger or default_logger) + ) + return task + + +def _handle_task_exception(task: asyncio.Task, logger: logging.Logger) -> None: + """Print out exception and tracebook when a task dies with an exception.""" + try: + task.result() + except asyncio.CancelledError: + # Cancellation should not be logged as an error. + pass + except Exception: + # `logger.exception` automatically handles exception and traceback info. + logger.exception("Job task died with an exception!") diff --git a/zeus_monitor/README.md b/zeus_monitor/README.md index 7da2b43d..3469a618 100644 --- a/zeus_monitor/README.md +++ b/zeus_monitor/README.md @@ -1,5 +1,8 @@ # [Deprecated] Zeus Power Monitor +`zeus_monitor` is deprecated and will be removed in 2024. +Please use the [`ZeusMonitor`](https://ml.energy/zeus/reference/monitor/energy/#zeus.monitor.energy.ZeusMonitor) for programmatic measurement and its CLI counterpart (`python -m zeus.monitor`) for command line measurement. + This is a simple GPU power monitor used by Zeus. It polls NVML and writes outputs to the designated log file path. Find a sample of its output in [`sample.csv`](sample.csv).