14 changes: 12 additions & 2 deletions applications/DeepSpeed-Chat/dschat/utils/ds_utils.py
@@ -22,7 +22,10 @@ def get_train_ds_config(offload,
enable_tensorboard=False,
enable_mixed_precision_lora=False,
tb_path="",
tb_name=""):
tb_name="",
offload_optimizer_config=None,
offload_param_config=None,
aio_config=None):

device = "cpu" if offload else "none"
if dtype == "fp16":
@@ -45,12 +48,16 @@ def get_train_ds_config(offload,
"stage3_prefetch_bucket_size": 3e7,
"memory_efficient_linear": False
}
if offload_optimizer_config:
zero_opt_dict["offload_optimizer"].update(offload_optimizer_config)
if offload_param_config:
zero_opt_dict["offload_param"].update(offload_param_config)
if enable_mixed_precision_lora:
zero_opt_dict["zero_quantized_nontrainable_weights"] = True
if dist.get_world_size() != get_accelerator().device_count():
zero_opt_dict["zero_hpz_partition_size"] = get_accelerator(
).device_count()
return {
config = {
"train_batch_size": GLOBAL_BATCH_SIZE,
"train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
"steps_per_print": 10,
@@ -73,6 +80,9 @@
"job_name": f"{tb_name}_tensorboard"
}
}
if aio_config:
config["aio"] = aio_config
return config


def get_eval_ds_config(offload, dtype, stage=0):
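A minimal usage sketch of the new parameters (not part of the diff; the nvme path and the numeric values below are placeholders, the keyword names come from this change):

from dschat.utils.ds_utils import get_train_ds_config

ds_config = get_train_ds_config(
    offload=True,
    dtype="bf16",
    stage=3,
    offload_optimizer_config={"device": "nvme", "nvme_path": "/mnt/nvme0", "pin_memory": True},
    offload_param_config={"device": "nvme", "nvme_path": "/mnt/nvme0", "pin_memory": True},
    aio_config={"block_size": 1048576, "queue_depth": 8, "single_submit": False, "overlap_events": True})
# The override dicts are merged into zero_opt_dict["offload_optimizer"] / ["offload_param"],
# and aio_config, when given, is attached under the top-level "aio" key of the returned config.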
18 changes: 7 additions & 11 deletions applications/DeepSpeed-Chat/dschat/utils/utils.py
@@ -74,8 +74,7 @@ def get(self):

def get_tokenizer(model_name_or_path, fast_tokenizer=True):
if "llama" in model_name_or_path:
from transformers.models.llama import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, fast_tokenizer=fast_tokenizer)
if tokenizer.pad_token is None:
# assert tokenizer.eos_token is not None
@@ -94,16 +93,13 @@ def get_tokenizer(model_name_or_path, fast_tokenizer=True):
def load_hf_tokenizer(model_name_or_path,
fast_tokenizer=True,
add_special_tokens=None):
if os.path.exists(model_name_or_path):
# Locally tokenizer loading has some issue, so we need to force download
model_json = os.path.join(model_name_or_path, "config.json")
if os.path.exists(model_json):
model_json_file = json.load(open(model_json))
model_name = model_json_file.get("_name_or_path",
model_name_or_path)
tokenizer = get_tokenizer(model_name,
fast_tokenizer=fast_tokenizer)
# Support loading from local path directly
if os.path.exists(model_name_or_path) and os.path.isdir(model_name_or_path):
# Directly load tokenizer from local path
tokenizer = get_tokenizer(model_name_or_path,
fast_tokenizer=fast_tokenizer)
else:
# Load from HuggingFace Hub or use original logic
tokenizer = get_tokenizer(model_name_or_path,
fast_tokenizer=fast_tokenizer)

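A short sketch of the resulting behaviour (illustrative calls only; the local path and hub id below are placeholders):

from dschat.utils.utils import load_hf_tokenizer

# Local checkpoint directory: the tokenizer is now loaded straight from the path,
# without resolving config.json's "_name_or_path" first.
tokenizer = load_hf_tokenizer("/path/to/local/checkpoint", fast_tokenizer=True)
# Hub model id: behaviour is unchanged and still goes through get_tokenizer().
tokenizer = load_hf_tokenizer("facebook/opt-1.3b", fast_tokenizer=True)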
@@ -5,6 +5,8 @@
# DeepSpeed Team
import argparse
import math
import os
from pprint import pformat

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
@@ -29,6 +31,18 @@
from dschat.utils.perf import print_throughput


def str2bool(value):
if isinstance(value, bool):
return value
lowered = value.lower()
if lowered in ("yes", "true", "t", "1"):
return True
if lowered in ("no", "false", "f", "0"):
return False
raise argparse.ArgumentTypeError(
f"Boolean value expected, got `{value}`.")


def parse_args():
parser = argparse.ArgumentParser(
description=
@@ -145,6 +159,80 @@ def parse_args():
parser.add_argument('--offload',
action='store_true',
help='Enable ZeRO Offload techniques.')
parser.add_argument('--offload_optimizer_device',
type=str,
choices=['cpu', 'nvme'],
default=None,
help='Device to use for ZeRO optimizer state offload.')
parser.add_argument('--offload_optimizer_nvme_path',
type=str,
default=None,
help='NVMe path used when offloading optimizer states to nvme.')
parser.add_argument('--offload_optimizer_pin_memory',
type=str2bool,
default=None,
help='Whether to pin optimizer offload memory (true|false).')
parser.add_argument('--offload_optimizer_ratio',
type=float,
default=None,
help='Ratio of optimizer state to keep on device when offloading.')
parser.add_argument('--offload_optimizer_buffer_count',
type=int,
default=None,
help='Number of optimizer offload buffers.')
parser.add_argument('--offload_optimizer_fast_init',
type=str2bool,
default=None,
help='Use fast init for optimizer offload buffers (true|false).')
parser.add_argument('--offload_param_device',
type=str,
choices=['cpu', 'nvme'],
default=None,
help='Device to use for ZeRO parameter offload.')
parser.add_argument('--offload_param_nvme_path',
type=str,
default=None,
help='NVMe path used when offloading parameters to nvme.')
parser.add_argument('--offload_param_pin_memory',
type=str2bool,
default=None,
help='Whether to pin parameter offload memory (true|false).')
parser.add_argument('--offload_param_buffer_size',
type=int,
default=None,
help='Parameter offload buffer size (number of elements). Increase if embedding layer is larger than the default.')
parser.add_argument('--offload_param_buffer_count',
type=int,
default=None,
help='Number of parameter offload buffers.')
parser.add_argument('--offload_param_max_in_cpu',
type=float,
default=None,
help='Maximum number of parameters to keep in CPU memory during offload.')
parser.add_argument('--aio_block_size',
type=int,
default=1048576,
help='AIO block size for NVMe offload (bytes).')
parser.add_argument('--aio_queue_depth',
type=int,
default=8,
help='AIO queue depth for NVMe offload.')
parser.add_argument('--aio_intra_op_parallelism',
type=int,
default=1,
help='AIO intra_op_parallelism for NVMe offload.')
parser.add_argument('--aio_single_submit',
type=str2bool,
default=False,
help='AIO single_submit flag.')
parser.add_argument('--aio_overlap_events',
type=str2bool,
default=True,
help='AIO overlap_events flag.')
parser.add_argument('--aio_use_gds',
type=str2bool,
default=False,
help='AIO use_gds flag.')
parser.add_argument('--dtype',
type=str,
default='fp16',
@@ -222,18 +310,91 @@ def main():

args.global_rank = torch.distributed.get_rank()

# Dynamically set the NVMe path based on local_rank
# If a path was already given on the command line, use it; otherwise derive it from local_rank
local_rank = args.local_rank if args.local_rank != -1 else 0

# Support configuring a list of NVMe paths via an environment variable (colon-separated)
# e.g. export NVME_PATHS="/mnt/deepspeed_nvme0:/mnt/deepspeed_nvme1"
nvme_paths_env = os.environ.get('NVME_PATHS', '')
if nvme_paths_env:
nvme_paths = [path.strip() for path in nvme_paths_env.split(':') if path.strip()]
if local_rank < len(nvme_paths):
default_nvme_path = nvme_paths[local_rank]
else:
default_nvme_path = nvme_paths[0] if nvme_paths else None
else:
# Default mapping: GPU 0 -> /mnt/deepspeed_nvme0, GPU 1 -> /mnt/deepspeed_nvme1, and so on
default_nvme_path = f"/mnt/deepspeed_nvme{local_rank}"

# If no optimizer nvme_path was given on the command line, use the path derived from local_rank
if args.offload_optimizer_nvme_path is None and default_nvme_path:
args.offload_optimizer_nvme_path = default_nvme_path
print_rank_0(f"Rank {args.global_rank} (local_rank {local_rank}) using optimizer NVMe path: {args.offload_optimizer_nvme_path}", args.global_rank)

# If no param nvme_path was given on the command line, use the path derived from local_rank
if args.offload_param_nvme_path is None and default_nvme_path:
args.offload_param_nvme_path = default_nvme_path
print_rank_0(f"Rank {args.global_rank} (local_rank {local_rank}) using param NVMe path: {args.offload_param_nvme_path}", args.global_rank)

offload_optimizer_overrides = {
"device": args.offload_optimizer_device,
"nvme_path": args.offload_optimizer_nvme_path,
"pin_memory": args.offload_optimizer_pin_memory,
"ratio": args.offload_optimizer_ratio,
"buffer_count": args.offload_optimizer_buffer_count,
"fast_init": args.offload_optimizer_fast_init
}
offload_optimizer_overrides = {
key: value
for key, value in offload_optimizer_overrides.items()
if value is not None
}
offload_param_overrides = {
"device": args.offload_param_device,
"nvme_path": args.offload_param_nvme_path,
"pin_memory": args.offload_param_pin_memory,
"buffer_size": args.offload_param_buffer_size,
"buffer_count": args.offload_param_buffer_count,
"max_in_cpu": args.offload_param_max_in_cpu
}
offload_param_overrides = {
key: value
for key, value in offload_param_overrides.items()
if value is not None
}
aio_config = {
"block_size": args.aio_block_size,
"queue_depth": args.aio_queue_depth,
"intra_op_parallelism": args.aio_intra_op_parallelism,
"single_submit": args.aio_single_submit,
"overlap_events": args.aio_overlap_events,
"use_gds": args.aio_use_gds,
}
ds_config = get_train_ds_config(offload=args.offload,
dtype=args.dtype,
stage=args.zero_stage,
enable_tensorboard=args.enable_tensorboard,
tb_path=args.tensorboard_path,
tb_name="step1_model")
tb_name="step1_model",
offload_optimizer_config=(
offload_optimizer_overrides
if offload_optimizer_overrides else None),
offload_param_config=(
offload_param_overrides
if offload_param_overrides else None),
aio_config=aio_config)
ds_config[
'train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size
ds_config[
'train_batch_size'] = args.per_device_train_batch_size * torch.distributed.get_world_size(
) * args.gradient_accumulation_steps
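# Worked example (illustrative): with per_device_train_batch_size=8, 2 visible GPUs and
# gradient_accumulation_steps=1 (the values used in the run script below), this gives
# train_micro_batch_size_per_gpu = 8 and train_batch_size = 8 * 2 * 1 = 16.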


# ds_config is fully assembled at this point, so print the configuration here
print_rank_0("***** DeepSpeed User Provided config *****", args.global_rank)
print_rank_0(pformat(ds_config), args.global_rank)

# If passed along, set the training seed now.
set_random_seed(args.seed)

@@ -245,6 +406,9 @@ def main():
fast_tokenizer=True,
add_special_tokens=additional_special_tokens)

print_rank_0("***** Tokenizer *****", args.global_rank)
print_rank_0(tokenizer, args.global_rank)

model = create_hf_model(AutoModelForCausalLM,
args.model_name_or_path,
tokenizer,
@@ -264,6 +428,10 @@
model = only_optimize_lora_parameters(model)
model = make_model_gradient_checkpointing_compatible(model)

# Print full model architecture (rank 0 only to avoid log spam)
print_rank_0("***** Model architecture *****", args.global_rank)
print_rank_0(model, args.global_rank)

# Prepare the data
train_phase = 1
train_dataset, eval_dataset = create_prompt_dataset(
@@ -319,6 +487,7 @@ def evaluation(model, eval_dataloader):
model, args.weight_decay, args.lora_learning_rate)

AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
print_rank_0(f"offload: {args.offload}", args.global_rank)
optimizer = AdamOptimizer(optimizer_grouped_parameters,
lr=args.learning_rate,
betas=(0.9, 0.95))
@@ -348,8 +517,9 @@ def evaluation(model, eval_dataloader):
print_rank_0(
f"***** Evaluating perplexity, Epoch {0}/{args.num_train_epochs} *****",
args.global_rank)
perplexity, eval_loss = evaluation(model, eval_dataloader)
print_rank_0(f"ppl: {perplexity}, loss: {eval_loss}", args.global_rank)
print_rank_0("Jump Evaluation", args.global_rank)
# perplexity, eval_loss = evaluation(model, eval_dataloader)
# print_rank_0(f"ppl: {perplexity}, loss: {eval_loss}", args.global_rank)

for epoch in range(args.num_train_epochs):
print_rank_0(
@@ -372,6 +542,11 @@ def evaluation(model, eval_dataloader):
if torch.distributed.get_rank() == 0:
print_throughput(model.model, args, end - start,
args.global_rank)

# Return early for debugging (stop after a few steps)
if step > 5:
return 0


# Evaluate perplexity on the validation set.
print_rank_0(
@@ -0,0 +1,71 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

source /home/ckr/miniconda3/etc/profile.d/conda.sh
conda activate ds

export http_proxy=http://127.0.0.1:7890
export https_proxy=http://127.0.0.1:7890

# DeepSpeed Team
OUTPUT=$1
ZERO_STAGE=$2
if [ "$OUTPUT" == "" ]; then
OUTPUT=./output_result/test_output
fi
if [ "$ZERO_STAGE" == "" ]; then
ZERO_STAGE=3
fi
mkdir -p $OUTPUT

# Set the Hugging Face token (read from the HF_TOKEN environment variable)
# Recommended: set it in ~/.bashrc or ~/.zshrc: export HF_TOKEN="your_token_here"
if [ -z "$HF_TOKEN" ]; then
echo "Warning: HF_TOKEN environment variable is not set. Please set it before running this script."
echo "You can set it by: export HF_TOKEN='your_token_here'"
fi
export HUGGING_FACE_HUB_TOKEN="${HF_TOKEN:-}"

# Optional: set a Hugging Face mirror endpoint if the network connection is unreliable
export HF_ENDPOINT="https://hf-mirror.com"
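
# Optional (illustrative): main.py also reads a colon-separated NVME_PATHS list to map each
# local rank to its own NVMe mount when --offload_optimizer_nvme_path / --offload_param_nvme_path
# are not given, e.g.:
# export NVME_PATHS="/mnt/deepspeed_nvme0:/mnt/deepspeed_nvme1"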

# --offload_param_nvme_path /mnt/raid0 \
CUDA_VISIBLE_DEVICES=0,1 deepspeed --master_port=29603 main.py \
--aio_block_size=16777216 \
--offload \
--offload_optimizer_device nvme \
--offload_optimizer_nvme_path /mnt/raid0 \
--offload_optimizer_pin_memory true \
--offload_optimizer_ratio 0.3 \
--offload_optimizer_buffer_count 8 \
--offload_optimizer_fast_init false \
--offload_param_device nvme \
--offload_param_pin_memory true \
--offload_param_buffer_size 360349696 \
--offload_param_buffer_count 32 \
--offload_param_max_in_cpu 0 \
--aio_use_gds true \
--dtype bf16 \
--data_path Dahoas/rm-static \
--data_split 2,4,4 \
--model_name_or_path /home/ckr/LoRA/models/70B \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--max_seq_len 512 \
--learning_rate 9.65e-6 \
--weight_decay 0. \
--num_train_epochs 4 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--num_warmup_steps 0 \
--seed 1234 \
--gradient_checkpointing \
--zero_stage $ZERO_STAGE \
--deepspeed \
--lora_dim 128 \
--lora_module_name "layers." \
--data_output_path ./data \
--output_dir $OUTPUT \
&> $OUTPUT/training.log
