Skip to content

Commit

Permalink
Open source Perseus (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaywonchung authored Oct 13, 2023
1 parent e78c9a0 commit 076df3d
Show file tree
Hide file tree
Showing 22 changed files with 1,763 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check_homepage_build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ jobs:
- name: Install other homepage dependencies
run: pip install -r docs/requirements.txt
- name: Build homepage
run: mkdocs build --verbose --strict
run: mkdocs build --verbose
89 changes: 89 additions & 0 deletions docs/perseus/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Perseus: Energy Scheduling in Large Model Training

!!! Warning
Perseus is under active development, and breaking changes may happen.
Currently, we have all the low-level APIs in place, but it's not a turnkey solution yet.
This document always reflects the master `HEAD`.

## Overview

Perseus finds the training time--energy Pareto frontier of large model training.
Users can pick any point on the frontier -- be it minimum time, minimum energy, or something in the middle, depending on the training deadline.

Large model training requires the distribution of work to multiple GPUs.
The core observation of Perseus is that work cannot be perfectly split and balanced across every GPU; some GPUs have more work to do and some less.
GPUs with smaller amounts of work finish before GPUs with more amounts of work, but ultimately training throughput is bound by GPUs with the most amount of work.
In other words, GPUs with lighter load are running unnecessarily fast and wasting energy (i.e., there is **energy bloat**).

We reduce enregy bloat by controlling the execution speed of each pipeline instruction (forward and backward) in each stage by controlling the GPU's frequency in a fine-grained manner.
We call the assignment of a GPU frequency to each pipeline instruction *frequency plan*, and Perseus gives you **every Pareto-optimal frequency plan** that you can choose any point on the iteration time--energy Pareto frontier.
These plans include frequency plans that do not make training any slower compared to not using Perseus at all, but yield free energy savings.
If you have a bit more leeway as to when training should finish (e.g., You're good as long as training finishes by tomorrow morning), you can pick the frequency plan that slows down training by a couple percentages and save more energy.
Our core algorithm, implemented as a separate library called [`lowtime`](https://github.com/ml-energy/lowtime), **provably guarantees** that for any time deadline, energy consumption is minimal.

## How it's done

Currently it's a three-step process:

1. **Profile**: Profile the computation time and energy consumption of the forward and backward instructions in *each stage* and *each GPU frequency*.
2. **Optimize**: Use [`lowtime`](https://github.com/ml-energy/lowtime) to generate all Pareto-optimal frequency plans.
3. **Choose and start training**: Among all the frequency plans generated by `lowtime`, choose the one that suits your use case.

We have a reference integration with the large model training framework [Merak](https://github.com/ml-energy/merak-zeus), which supports 3D parallelism and automatically tracing and partitioning Hugging Face models.
We've smoothed out some rough edges, integrated Zeus and Perseus, and maintained example scripts for GPT3, BERT, and Wide-ResNet (pretty much any `torchvision` model).

You don't have to be tied to Merak.
If you have your own training framework, and you can integrate Perseus following [our integration guide](integrating.md).

### Profile

In order to run our optimization algorithm, we need the time & energy profiling information of the forward and backward instruction in each stage for every GPU frequency.
The CSV file should look like this for a 4-stage pipeline:

```csv
stage,instruction,frequency,time,energy
0,forward,1740,0.09373254776000976,28.4944
0,forward,1725,0.09390360514322917,28.434366666666666
0,forward,1710,0.09381131331125896,28.288966666666667
...
0,backward,1740,0.24533510557810465,69.5691
0,backward,1725,0.24538559118906658,69.2552
0,backward,1710,0.24548352559407552,68.89453333333334
...
3,backward,690,0.4184921979904175,68.12243333333333
3,backward,675,0.42459266185760497,68.77603333333334
3,backward,660,0.4306272824605306,69.39623333333334
```

Since different frameworks and model implementations will have different performance, it's best to obtain these profiling results on the framework and model you'll be using.
That being said, you can obtain this profiling information in however way you want as long as they have all the columns in the reference CSV file above.
But as a reference, we have implemented an automatic profiler in Merak.
Please refer to the [examples](https://github.com/ml-energy/merak-zeus/tree/main/examples) directory in Merak for profiling instructions.

!!! Tip
As you profile the time and energy consumption of an instruction, you will scan down from the highest to the lowest frequency.
However, as you lower the GPU's frequency, both time and energy will start to inflate after some point.
In other words, those frequencies take more time **and** energy and are simply inefficient (i.e., Pareto-suboptimal), so we won't be running anything with those frequencies.
Therefore, you actually don't need to profile time and energy for *every* frequency.
A good heuristic is to scan from higher frequencies to lower ones, and once energy consumption increases more than five *consecutive* times, just stop there.

### Optimize

With the CSV file that holds profiling results, you can use `lowtime` to generate all Pareto-optimal frequency plans.

See [`examples/perseus`](https://github.com/ml-energy/zeus/tree/master/examples/perseus) to find the script `run_optimization.py`.

### Choose and start training

Running `lowtime` optimization will produce a set of frequency assignment files (`freqs_pipeline_%05d.py`).
Each file is also annotated with estimates for time and cost.
The larger the number, the shorter the expected iteration time.

Then, start the Perseus server and plug in the frequency plan you chose:

```console
$ docker exec -it merak-zeus bash
# PERSEUS_SCHEDULER_ARGS='{"solution_path": "path/to/freqs_pipeline_%05d.py"}' uvicorn zeus.optimizer.perseus.server.router:app --port 7787
```

When you run training (with the same `run.sh` but without `--profile true`), the [`PerseusOptimizer`][zeus.optimizer.perseus.optimizer.PerseusOptimizer] integrated into your training framework will automatically talk with the Perseus server to figure out the right GPU frequency to set for the upcoming pipeline instruction and transparently set the GPU's frequency.
4 changes: 4 additions & 0 deletions docs/perseus/integrating.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Integrating `PerseusOptimizer`

This page is currently under construction.
We have a reference integration with [Merak](https://github.com/ml-energy/merak-zeus).
245 changes: 245 additions & 0 deletions examples/perseus/run_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# Copyright (C) 2023 Jae-Won Chung <jwnchung@umich.edu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example script of running Perseus energy scheduling."""

from __future__ import annotations

import time
import itertools
import logging
from pathlib import Path
from typing import Type
from collections import defaultdict
from dataclasses import dataclass

import tyro
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

from lowtime.operation import (
CandidateExecutionOptions,
OperationSpec,
ExecutionOption,
)
from lowtime.cost_model import ExponentialModel
from lowtime.perseus.instruction import (
Instruction,
Forward,
Backward,
forward_dep,
backward_dep,
)
from lowtime.solver import PhillipsDessouky
from lowtime.graph_utils import add_source_node, add_sink_node, DependencyResolver
from lowtime.perseus.schedule import Synchronous1F1B
from lowtime.perseus.visualizer import PipelineVisualizer, ANNOTATE_ARGS, LINE_ARGS

logger = logging.getLogger()


@dataclass
class Args:
# Path to instruction profile results
inst_profile: str
# GPU power consumption while blocking on P2P communication, in Watts
p2p_power: float = 70.0
# Directory to output results
output_dir: Path
# Number of microbatchs
num_mbs: int
# Number of stages
num_stages: int
# Interval to draw the state of the pipeline
plot_interval: int = 100
# The unit of reduction for each iteration, in seconds
unit_time: float = 0.001


def main(args: Args) -> None:
"""Perseus time-cost tradeoff optimization."""
# Setup logging and output.
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=False)
log_path = output_dir / "job.log"

logging.basicConfig(
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
level=logging.INFO,
handlers=[logging.FileHandler(log_path, mode="a"), logging.StreamHandler()],
)
logger.info("Arguments: %s", args)

# Instruction offline profiling results.
inst_df = pd.read_csv(args.inst_profile)

####################
# Execution Option #
####################
# Construct the OperationSpec object of each pipeline instruction in each stage.
op_spec_map: dict[int, dict[Type[Instruction], OperationSpec]] = defaultdict(dict)
for instruction in [Forward, Backward]:
inst_name = instruction.__name__
for stage_id in range(args.num_stages):
logger.info("Processing %s stage %d", inst_name, stage_id)
options = []
_df = inst_df.query(
f"stage == {stage_id} and instruction == '{inst_name.lower()}'"
)
for _, row in _df.iterrows():
row = row.to_dict()
options.append(
ExecutionOption[int](
real_time=row["time"],
unit_time=args.unit_time,
cost=row["energy"],
knob=int(row["frequency"]),
)
)

# Get the Preto frontier, quantize time, and deduplicate time.
cand_options = CandidateExecutionOptions[int](options=options)

# Map the cost to be effective computation energy.
# Everything from now on is in terms of effective energy.
for option in cand_options.options:
option.cost -= args.p2p_power * option.quant_time * option.unit_time

# Fit the cost model.
model = ExponentialModel(cand_options)

# Draw the cost model.
fig, ax = plt.subplots(figsize=(8, 8), tight_layout=True)
model.draw(ax, cand_options)
fig.savefig(f"{output_dir}/{inst_name.lower()}_{stage_id}.png")

# Initialize the operation spec.
op_spec = OperationSpec[int](options=cand_options, cost_model=model)
op_spec_map[stage_id][instruction] = op_spec

####################
# DAG construction #
####################
dag = nx.DiGraph()

# Generate and add all instructions to the DAG.
# Reserve 0 for dummy source and 1 for dummy sink.
node_id = 2
instructions: list[list[Instruction]] = []
for stage_id in range(args.num_stages):
# Generate instructions for each stage.
stage_insts: list[Instruction] = []
stage_node_ids: list[int] = []
for inst in Synchronous1F1B(
num_stages=args.num_stages,
num_micro_batches=args.num_mbs,
stage_id=stage_id,
operation_spec_map=op_spec_map[stage_id],
):
dag.add_node(node_id, op=inst)
stage_insts.append(inst)
stage_node_ids.append(node_id)
node_id += 1
instructions.append(stage_insts)

# Add dependencies between adjacent instructions in the same stage.
for node_id1, node_id2 in zip(stage_node_ids, stage_node_ids[1:]):
dag.add_edge(node_id1, node_id2)

# Add dependencies between dependent pipeline instructions.
insts = dag.nodes(data=True)
resolver = DependencyResolver(
dependency_rules=[forward_dep, backward_dep],
node_type=Instruction,
)
for (id1, data1), (id2, data2) in itertools.product(insts, insts):
if resolver.is_dependent(data1["op"], data2["op"]):
dag.add_edge(id1, id2)

# Add source and sink nodes.
add_source_node(dag, 0)
add_sink_node(dag, 1)
dag.graph["source_node"] = 0
dag.graph["sink_node"] = 1

###################################
# Time-cost tradeoff optimization #
###################################
def annotation_hook(inst: Instruction) -> str:
return f"{type(inst).__name__[0]}\n{inst.micro_batch_id}"

def draw(dag: nx.DiGraph, iteration: int, xlim: int) -> None:
ANNOTATE_ARGS[Forward]["fontsize"] = 11.0
ANNOTATE_ARGS[Backward]["fontsize"] = 11.0
ANNOTATE_ARGS[Forward]["color"] = "black"
ANNOTATE_ARGS[Backward]["color"] = "black"
LINE_ARGS["linewidth"] = 3.0

fig, ax = plt.subplots(figsize=(args.num_mbs, 4), tight_layout=True)

vis = PipelineVisualizer(dag)
vis.draw(
ax,
draw_time_axis=True,
p2p_power=args.p2p_power,
annotation_hook=annotation_hook,
power_color="RdBu_r",
normalizer_range=(-200, 550),
)
vis.draw_critical_path(ax)

# Fix xlim so that we can visually see the pipeline width shrink.
ax.set_xlim(0.0, xlim)
ax.set_title(f"Iteration {iteration:4d}")
fig.savefig(f"{output_dir}/pipeline_{iteration:05d}.png")
plt.close(fig)

solver = PhillipsDessouky(dag)

draw_xlim = None
iteration = 0
for iteration, result in enumerate(solver.run()):
# Maybe draw the pipeline.
if iteration % args.plot_interval == 0:
if draw_xlim is None:
draw_xlim = int(result.real_time) + 1
draw(dag, iteration, draw_xlim)

# Write the frequency assignment Python file.
f = open(args.output_dir / f"freqs_pipeline_{iteration:05d}.py", "w")
f.write("[\n")
for stage_id, stage_insts in enumerate(instructions):
stage_freq: list[tuple[str, int]] = []
for inst in stage_insts:
stage_freq.append((type(inst).__name__.lower(), inst.assigned_knob))
f.write(f"{stage_freq},\n")
f.write("]\n")

iter_str = f"# Iteration {iteration} "
real_cost = result.cost + args.num_stages + result.real_time * args.p2p_power
f.write(iter_str + f"cost change: {result.cost_change}\n")
f.write(iter_str + f"total cost: {result.cost}\n")
f.write(iter_str + f"total cost with P2P: {real_cost}\n")

assert draw_xlim is not None
draw(dag, iteration, draw_xlim)


if __name__ == "__main__":
args = tyro.cli(Args)

start_time = time.time()
main(args)
logger.info("Total time: %.2fs", time.time() - start_time)
Binary file added examples/perseus/wide-resnet.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ nav:
- getting_started/index.md
- Environment Setup: getting_started/environment.md
- Installing and Building: getting_started/installing_and_building.md
- Perseus:
- perseus/index.md
- Integrating: perseus/integrating.md
- Extending Zeus: extend.md
- Source Code Reference: reference/

Expand Down
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ dependencies = [
"nvidia-ml-py",
"pydantic<2",
"rich",
"tyro",
"fastapi[all]==0.87.0",
"httpx",
"aiofiles==22.1.0",
"lowtime",
]
dynamic = ["version"]

Expand All @@ -38,8 +43,7 @@ Documentation = "https://ml.energy/zeus"
[project.optional-dependencies]
lint = ["ruff", "black==22.6.0"]
test = ["pytest==7.3.2", "pytest-mock==3.10.0", "pytest-xdist==3.3.1"]
torch = ["torch==2.0.1"]
dev = ["ruff", "black==22.6.0", "pytest==7.3.2", "pytest-mock==3.10.0", "pytest-xdist==3.3.1", "torch==2.0.1"]
dev = ["ruff", "black==22.6.0", "pytest==7.3.2", "pytest-mock==3.10.0", "pytest-xdist==3.3.1"]

[tool.setuptools.packages.find]
where = ["."]
Expand Down Expand Up @@ -76,6 +80,8 @@ convention = "google"

[tool.ruff.per-file-ignores]
"**/__init__.py" = ["F401", "F403"]
"zeus/optimizer/perseus/common.py" = ["N805"]
"zeus/optimizer/perseus/server/router.py" = ["B008"]

[tool.pytest.ini_options]
addopts = "--numprocesses auto"
2 changes: 1 addition & 1 deletion zeus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@
- [`util`][zeus.util]: Utility functions and classes.
"""

__version__ = "0.7.1"
__version__ = "0.8.0"
Loading

0 comments on commit 076df3d

Please sign in to comment.