add unbalanced param_sync example. #126

Open: wants to merge 8 commits into main.

Changes from 2 commits
49 changes: 49 additions & 0 deletions examples/megatron/configs/llama2/rlhf_param_sync.yaml
@@ -0,0 +1,49 @@
runtime_env:
  platform: DLC
  excludes:
    - "*pt"
    - "logs"
    - "tensorboards"
    - ".nfs*"


models:
  policy:
    model_config_file: old_policy_inference.yaml
    num_gpu: ${num_gpu_policy:16}
    trainable: False
    batch_generation:
      ranking: ${batch_generation_ranking:False}
      min_prompt_length: ${batch_generation_min_prompt_length:0}
    free_memory: ${free_memory_policy:False}

  ppo_policy:
    model_config_file: ppo_policy.yaml
    num_gpu: ${num_gpu_ppo_policy:16}
    trainable: True
    lora:
      enable_lora: ${enable_lora_policy:False}
      lora_dim: 64
      lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
      column_only_qkv: False
      lora_dropout: 0.05
    free_memory: ${free_memory_ppo_policy:False}

runtime:
  colocation:
    - policy,ppo_policy
  generation_batch_size: ${generation_batch_size:4}
  train_micro_batch_size: ${train_micro_batch_size:2}
  train_global_batch_size: ${train_global_batch_size:512}
  num_episode: ${num_episode:200}
  sample_per_episode: ${sample_per_episode:1024}
  num_training_epoch: 1
  save_episode_interval: ${save_episode_interval:50}
  data_path: ${data_path}
  eval_data_path: ${eval_data_path}
  training_data_num_limit: ${training_data_num_limit:-1}
  eval_data_num_limit: ${eval_data_num_limit:128}
  eval_episode_interval: ${eval_episode_interval:100}
  data_checkpoint_path: ${data_checkpoint_path}
  output_dir: ${output_dir}
  exp_name: ${exp_name:chatlearn}
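
A note on the ${name:default} entries above: judging from the launcher script later in this diff, each placeholder is filled from an environment variable of the same name, with the value after the colon used when the variable is unset. A minimal sketch of overriding a few of them before launch (the values are illustrative, not recommendations):

# Override selected placeholders via environment variables before launching;
# anything left unset falls back to the inline default in the YAML.
export num_gpu_policy=8            # overrides ${num_gpu_policy:16}
export num_gpu_ppo_policy=8        # overrides ${num_gpu_ppo_policy:16}
export generation_batch_size=128   # overrides ${generation_batch_size:4}
export data_path=/path/to/train/data   # ${data_path} has no inline default, so it must be set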
49 changes: 49 additions & 0 deletions examples/megatron/configs/llama2/vllm_param_sync.yaml
@@ -0,0 +1,49 @@
runtime_env:
  platform: DLC
  excludes:
    - "*pt"
    - "logs"
    - "tensorboards"
    - ".nfs*"


models:
  policy:
    model_config_file: vllm_policy_inference.yaml
    num_gpu: ${num_gpu_policy:16}
    trainable: False
    batch_generation:
      ranking: ${batch_generation_ranking:False}
      min_prompt_length: ${batch_generation_min_prompt_length:0}
    free_memory: ${free_memory_policy:False}

  ppo_policy:
    model_config_file: ppo_policy.yaml
    num_gpu: ${num_gpu_ppo_policy:16}
    trainable: True
    lora:
      enable_lora: ${enable_lora_policy:False}
      lora_dim: 64
      lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
      column_only_qkv: False
      lora_dropout: 0.05
    free_memory: ${free_memory_ppo_policy:False}

runtime:
  colocation:
    - policy,ppo_policy
  generation_batch_size: ${generation_batch_size:4}
  train_micro_batch_size: ${train_micro_batch_size:2}
  train_global_batch_size: ${train_global_batch_size:512}
  num_episode: ${num_episode:200}
  sample_per_episode: ${sample_per_episode:1024}
  num_training_epoch: 1
  save_episode_interval: ${save_episode_interval:50}
  data_path: ${data_path}
  eval_data_path: ${eval_data_path}
  training_data_num_limit: ${training_data_num_limit:-1}
  eval_data_num_limit: ${eval_data_num_limit:128}
  eval_episode_interval: ${eval_episode_interval:100}
  data_checkpoint_path: ${data_checkpoint_path}
  output_dir: ${output_dir}
  exp_name: ${exp_name:chatlearn}
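
This config mirrors rlhf_param_sync.yaml above; the only difference is that the policy model's model_config_file points to vllm_policy_inference.yaml instead of old_policy_inference.yaml. As the launcher script later in this diff suggests, the vLLM path also needs ENABLE_VLLM set so that the test imports the vLLM inference class. A minimal sketch (config_file path assumes the ChatLearn checkout layout used by the launcher):

export ENABLE_VLLM=True   # makes test_unbalanced_param_sync.py pick VLLMPolicyInference
config_file=${CHATLEARN}/examples/megatron/configs/llama2/vllm_param_sync.yaml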
65 changes: 65 additions & 0 deletions examples/megatron/tests/test_unbalanced_param_sync.py
@@ -0,0 +1,65 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""entry file for training RLHF"""

import os

from examples.megatron.models import PolicyTrainer
from examples.megatron.models.train_helper import get_prompts

import chatlearn
from chatlearn.models.base_module import BaseModule
from chatlearn.runtime.engine import Engine
from chatlearn.runtime.environment import Environment
from chatlearn.runtime.trainer import Trainer

# pylint: disable=invalid-envvar-default,bad-exception-cause,ungrouped-imports
if os.getenv("ENABLE_VLLM", False):
    try:
        from examples.megatron.models import VLLMPolicyInference as PolicyModel
    except Exception as e:
        raise RuntimeError("Cannot import vllm, please set vllm python path or install vllm first.") from e
else:
    from examples.megatron.models import PolicyInference as PolicyModel


class CustomEngine(Engine):
    """Custom engine for param sync from ppo_policy to policy."""
    def __init__(self,
                 policy: BaseModule,
                 policy_trainer: BaseModule):
        def env_compute_flow(batch):
            policy_out = policy.forward_step(batch)
            return policy_out

        def trainer_compute_flow(batch):
            policy_trainer.train_step(batch)

        env = Environment(env_compute_flow)
        trainer = Trainer(trainer_compute_flow)
        super().__init__(env, trainer, name='ParamSync')
        self.set_parameter_sync(policy_trainer, policy)


if __name__ == "__main__":
    chatlearn.init()
    args = chatlearn.get_args()
    ppo_policy = PolicyTrainer("ppo_policy")
    policy_model = PolicyModel("policy")

    engine = CustomEngine(policy_model, ppo_policy)
    train_prompts = get_prompts(args.runtime_args.data_path, num_limit=args.runtime_args._args_dict['training_data_num_limit'])
    engine.set_dataset(train_prompts)
    engine.learn()
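
The "unbalanced" part appears to come from the parallel layouts the launcher below assigns: for llama2-7B, ppo_policy trains with tp=2/pp=4 while policy serves inference with tp=8/pp=1, so set_parameter_sync(policy_trainer, policy) has to remap trainer weights onto a differently sharded inference model. A direct-invocation sketch, assuming the working directory is examples/megatron (as in the launcher) and the YAML placeholders have already been exported:

# Run the param-sync test directly against the Megatron-backend config.
python tests/test_unbalanced_param_sync.py -c configs/llama2/rlhf_param_sync.yaml 2>&1 | tee param_sync.log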
58 changes: 58 additions & 0 deletions examples/megatron/tests/test_unbalanced_param_sync.sh
@@ -0,0 +1,58 @@
#!/bin/bash
set -x


[ -z "$model_size" ] && export model_size=llama2-7B

source scripts/base_env.sh

backend=${1:-vllm}

[ -z "$max_new_tokens" ] && export max_new_tokens=512
[ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}
[ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/
[ -z "$DATA_DIR" ] && DATA_DIR=${output_dir}/gpt/
output_dir=${output_dir}/${exp_name}
export data_checkpoint_path=${output_dir}/data_checkpoint

mkdir -p $output_dir

export max_seq_len=$(( max_new_tokens*2 ))

config_dir=${CHATLEARN}/examples/megatron/configs/


if [[ "$model_size" == "llama2-7B" ]]; then
    export policy_tp=8
    export policy_pp=1
    export ppo_policy_tp=2
    export ppo_policy_pp=4
    export train_global_batch_size=128
    if [[ "$backend" == "megatron" ]]; then
        export generation_batch_size=128
        config_file=${config_dir}/llama2/rlhf_param_sync.yaml
    elif [[ "$backend" == "vllm" ]]; then
        export ENABLE_VLLM=True
        export generation_batch_size=128
        config_file=${config_dir}/llama2/vllm_param_sync.yaml
    fi
    export train_micro_batch_size=16
    export max_num_batched_tokens=65536
    export gpu_memory_utilization=0.8

    export num_gpu_policy=8
    export num_gpu_ppo_policy=8
    export free_memory_policy=True
    export free_memory_ppo_policy=True
fi

policy_inference_load=${POLICY_LOAD} \
reward_load_iteration=${REWARD_LOAD_ITERATION} \
reward_load=${REWARD_LOAD} \
tokenizer_model=${TOKENIZER_MODEL} \
num_episode=${num_ppo_episode:-0} \
data_path=${DATASET_PATH} \
eval_data_path=${EVAL_DATASET_PATH} \
sample_per_episode=${sample_per_episode} \
tensorboard_dir=${TENSORBOARD_DIR} \
python tests/test_unbalanced_param_sync.py -c $config_file 2>&1 | tee ${output_dir}/log_${RANK}.log ; exit ${PIPESTATUS[0]}
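
A usage sketch for the launcher, assuming it is run from examples/megatron with CHATLEARN and the checkpoint/dataset variables it references exported first (all paths below are placeholders; the positional backend argument defaults to vllm):

export CHATLEARN=/path/to/chatlearn
export POLICY_LOAD=/path/to/policy/checkpoint
export TOKENIZER_MODEL=/path/to/tokenizer.model
export DATASET_PATH=/path/to/train/data
export EVAL_DATASET_PATH=/path/to/eval/data
export sample_per_episode=1024
cd $CHATLEARN/examples/megatron
bash tests/test_unbalanced_param_sync.sh vllm    # or: bash tests/test_unbalanced_param_sync.sh megatron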