From 773783e6b517ed1561e6dda4111a6238ff17b835 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BE=97=E7=A6=8F?=
Date: Sat, 13 Dec 2025 23:02:24 +0800
Subject: [PATCH] support ray

---
 F2LLM/README.md               |   5 +
 F2LLM/configs/config_ray.json |  21 +++
 F2LLM/requirements.txt        |   3 +-
 F2LLM/run_ray.py              | 297 ++++++++++++++++++++++++++++++
 F2LLM/train_with_ray.sh       |  38 ++++
 F2LLM/utils_ray.py            | 329 ++++++++++++++++++++++++++++++++++
 6 files changed, 692 insertions(+), 1 deletion(-)
 create mode 100644 F2LLM/configs/config_ray.json
 create mode 100644 F2LLM/run_ray.py
 create mode 100755 F2LLM/train_with_ray.sh
 create mode 100644 F2LLM/utils_ray.py

diff --git a/F2LLM/README.md b/F2LLM/README.md
index 6b79819..1aa5194 100644
--- a/F2LLM/README.md
+++ b/F2LLM/README.md
@@ -38,6 +38,11 @@ For multi-node training, run on the main node:
 accelerate launch --config_file configs/accelerate_config.yaml --num_machines N_NODE --num_processes N_PROCESSES --machine_rank 0 --main_process_ip MASTER_IP --main_process_port MASTER_PORT run.py --config configs/config.json
 ```
 
+To train with the Ray version:
+```
+bash train_with_ray.sh configs/config_ray.json
+```
+
 where N_NODE is the number of machines; N_PROCESSES is N_NODE\*8; MASTER_IP is the IP address of your master node, and MASTER_PORT is a port available on your machine (e.g. 6379).
 
 On worker nodes, also run the above command but modify `machine_rank` accordingly.
diff --git a/F2LLM/configs/config_ray.json b/F2LLM/configs/config_ray.json
new file mode 100644
index 0000000..f2ab1d7
--- /dev/null
+++ b/F2LLM/configs/config_ray.json
@@ -0,0 +1,21 @@
+{
+    "model_path": "/ossfs/workspace/CodeFuse-Embeddings-wjp/F2LLM/models/qwen3-0.6b",
+    "experiment_id": "ray_distributed_training",
+    "output_dir": "/ossfs/workspace/CodeFuse-Embeddings-wjp/F2LLM/output_8_gpu_ds",
+    "tb_dir": "/ossfs/workspace/CodeFuse-Embeddings-wjp/F2LLM/output_8_gpu_ds/tb",
+    "cache_dir": "/ossfs/workspace/CodeFuse-Embeddings-wjp/F2LLM/cache",
+    "train_data_path": "/ossfs/workspace/CodeFuse-Embeddings-wjp/F2LLM/data_tokenized_qwen",
+    "train_batch_size": 16,
+    "max_seq_length": 1024,
+    "learning_rate": 8e-6,
+    "min_lr": 1e-7,
+    "weight_decay": 0.01,
+    "warmup_steps": 64,
+    "num_hard_neg": 7,
+    "train_steps": -1,
+    "train_epochs": 2,
+    "log_interval": 1,
+    "checkpointing_steps": 500,
+    "validation_steps": 100,
+    "num_processes": 8
+}
diff --git a/F2LLM/requirements.txt b/F2LLM/requirements.txt
index 82fb447..3214f43 100644
--- a/F2LLM/requirements.txt
+++ b/F2LLM/requirements.txt
@@ -3,5 +3,6 @@ datasets
 deepspeed
 flash-attn
 torch
-transformers
+transformers>=4.51.0
 tensorboard
+ray[train]
diff --git a/F2LLM/run_ray.py b/F2LLM/run_ray.py
new file mode 100644
index 0000000..57935f3
--- /dev/null
+++ b/F2LLM/run_ray.py
@@ -0,0 +1,297 @@
+from arguments import parse_args
+from utils_ray import ray_train, CLASSIFICATION_DATASETS
+from transformers import (
+    AutoTokenizer,
+    set_seed,
+    get_scheduler
+)
+import os, json, random
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from torch.optim import AdamW
+from model import F2LLM
+import ray
+from ray import train
+from ray.train import ScalingConfig
+from ray.train.torch import TorchTrainer
+import deepspeed
+import yaml
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+def load_deepspeed_config_from_accelerate(accelerate_config_path, args, world_size):
+    """
+    Load the DeepSpeed settings from accelerate_config.yaml and convert them to DeepSpeed's format.
+    Mirrors Accelerate's behavior in run.py: only options that actually appear in the config file are read.
+    """
+    with open(accelerate_config_path, 'r') as f:
+        accelerate_config = yaml.safe_load(f)
+
+    # Check whether DeepSpeed is used at all
+    if accelerate_config.get('distributed_type') != 'DEEPSPEED':
+        return None
+
+    deepspeed_config = accelerate_config.get('deepspeed_config', {})
+
+    # Build the DeepSpeed config - only include options that exist in accelerate_config.yaml
+    ds_config = {
+        "train_batch_size": args.train_batch_size * world_size,
+        "train_micro_batch_size_per_gpu": args.train_batch_size,
+    }
+
+    # Options read from accelerate_config.yaml
+    if 'gradient_accumulation_steps' in deepspeed_config:
+        ds_config["gradient_accumulation_steps"] = deepspeed_config['gradient_accumulation_steps']
+
+    if 'gradient_clipping' in deepspeed_config:
+        ds_config["gradient_clipping"] = deepspeed_config['gradient_clipping']
+
+    # ZeRO optimization config
+    if 'zero_stage' in deepspeed_config:
+        zero_stage = deepspeed_config['zero_stage']
+        ds_config["zero_optimization"] = {"stage": zero_stage}
+
+    # Offload config
+    if 'offload_optimizer_device' in deepspeed_config:
+        ds_config["zero_optimization"]["offload_optimizer"] = {
+            "device": deepspeed_config['offload_optimizer_device']
+        }
+
+    if 'offload_param_device' in deepspeed_config:
+        ds_config["zero_optimization"]["offload_param"] = {
+            "device": deepspeed_config['offload_param_device']
+        }
+
+    # Mixed precision config
+    mixed_precision = accelerate_config.get('mixed_precision', 'no')
+    if mixed_precision == "fp16":
+        ds_config["fp16"] = {"enabled": True}
+    elif mixed_precision == "bf16":
+        ds_config["bf16"] = {"enabled": True}
+
+    print(f"Loaded DeepSpeed config: {ds_config}")
+
+    return ds_config
+
+
+def _stack(input_ids, max_len):
+    data = [ids[:max_len] for ids in input_ids]  # input_ids: list of lists
+    lens = [len(x) for x in data]
+    tensor = torch.tensor(sum(data, []))  # (total_tokens,)
+    return tensor.split(lens)  # list of 1-d tensors
+
+
+def get_collate_fn(tokenizer, args):
+    def collate_fn(batch_raw):
+        '''
+        length of input_ids: bs * (2 + num_hard_neg)
+        0 - bs-1: query input ids
+        bs - 2*bs-1: passage input ids
+        2*bs - 2*bs+num_hard_neg-1: hard neg for sample 1
+        2*bs+num_hard_neg*(i-1) - 2*bs+num_hard_neg*i-1: hard neg for sample i (i from 1 to bs)
+        '''
+        num_hard_neg = 1 if batch_raw[0]['dataset_name'] in CLASSIFICATION_DATASETS else args.num_hard_neg
+        # select args.num_hard_neg hard negatives from a total of 24
+        hard_neg_indices = [0] if num_hard_neg == 1 else random.sample(list(range(24)), num_hard_neg)
+        input_ids = _stack(
+            [s['query_input_ids'] for s in batch_raw]+\
+            [s['passage_input_ids'] for s in batch_raw]+\
+            [s[f'negative_{i+1}_input_ids'] for s in batch_raw for i in hard_neg_indices],
+            args.max_seq_length
+        )
+        seqlens = torch.tensor([ids.size(0) for ids in input_ids])
+        # pad input ids to [bs, max_len]
+        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+        attention_masks = input_ids.ne(tokenizer.pad_token_id).long()
+
+        return {'input_ids': input_ids, 'seq_lens': seqlens, 'attention_mask': attention_masks, 'bs': len(batch_raw), 'dataset_name': batch_raw[0]['dataset_name']}
+
+    return collate_fn
+
+
+class MultiLoader:
+    """
+    Iterates over a dict(name -> DataLoader) and returns complete batches.
+    A new random sampling order is seeded for each epoch via reset_epoch(),
+    which must be called before iterating; the epoch ends when every loader
+    has been exhausted once.
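+
+    A minimal usage sketch (the loader names below are illustrative only,
+    not part of this patch):
+
+        loaders = {"msmarco": DataLoader(...), "snli": DataLoader(...)}
+        multi = MultiLoader(loaders)
+        for epoch in range(num_epochs):
+            multi.reset_epoch(epoch)
+            for batch in multi:
+                ...  # each batch comes from a single, randomly chosen dataset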
+ """ + def __init__(self, loader_dict): + self.loader_dict = loader_dict + + def __len__(self): + return sum(len(v) for v in self.loader_dict.values()) + + def reset_epoch(self, epoch): + self.rng = random.Random(epoch) + self.iters = {k: iter(v) for k, v in self.loader_dict.items()} + self.names = list(self.iters.keys()) + self.weights = [len(self.loader_dict[k]) for k in self.names] + + def __iter__(self): + while self.names: # until every DataLoader is empty + name = self.rng.choices(self.names, weights=self.weights)[0] # pick a data-source at random + try: + batch = next(self.iters[name]) + yield batch + except StopIteration: + idx = self.names.index(name) + self.names.pop(idx) # this dataset has no batch left + self.weights.pop(idx) + + +def train_func(config): + """Ray Train training function""" + # Get configuration + args = config["args"] + + # Set seed + set_seed(0) + + # Get world info + world_size = train.get_context().get_world_size() + world_rank = train.get_context().get_world_rank() + local_rank = train.get_context().get_local_rank() + + # Set device + device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + + if world_rank == 0: + os.makedirs(f"{args.output_dir}", exist_ok=True) + with open(os.path.join(args.output_dir, "args.json"), "w") as f: + json.dump(args.dict(), f, indent=2) + + # Load datasets + train_datasets, valid_datasets = [], [] + for f in sorted(os.listdir(args.train_data_path)): + dataset_name = f.split('.parquet')[0] + dataset = load_dataset("parquet", data_files=os.path.join(args.train_data_path, f), cache_dir=args.cache_dir)['train'] + dataset = dataset.add_column("dataset_name", [dataset_name]*len(dataset)) + dataset = dataset.train_test_split(train_size=0.99, shuffle=True, seed=0) + train_datasets.append((dataset_name, dataset['train'])) + valid_datasets.append((dataset_name, dataset['test'])) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + collate_fn = get_collate_fn(tokenizer, args) + + # Create data loaders + train_loaders = { + name: DataLoader(ds, shuffle=True, batch_size=args.train_batch_size, collate_fn=collate_fn) + for name, ds in train_datasets + } + valid_loaders = { + name: DataLoader(ds, shuffle=False, batch_size=args.train_batch_size, collate_fn=collate_fn) + for name, ds in valid_datasets + } + + # Prepare Ray Train DataLoader + from ray.train.torch import prepare_data_loader + for k in train_loaders.keys(): + train_loaders[k] = prepare_data_loader(train_loaders[k]) + for k in valid_loaders.keys(): + valid_loaders[k] = prepare_data_loader(valid_loaders[k]) + + train_dataloader = MultiLoader(train_loaders) + + # Determine training steps + override_train_step = False + if args.train_steps < 0: + args.train_steps = sum(len(v) for v in train_loaders.values()) * args.train_epochs + override_train_step = True + + if world_rank == 0: + print(f"******************************** Training step before prepare: {args.train_steps} ********************************") + + # Initialize model + model = F2LLM(args.model_path, args.max_seq_length, args=args) + model.lm.gradient_checkpointing_enable() + + # Set seed again + set_seed(0) + + # Create optimizer and scheduler (like run.py) + optimizer = AdamW(model.lm.parameters(), + weight_decay=args.weight_decay, + lr=args.learning_rate, + betas=(0.9, 0.98)) + + lr_scheduler = get_scheduler("cosine", + optimizer=optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=args.train_steps) + + # Load DeepSpeed configuration from 
accelerate_config.yaml
+    accelerate_config_path = os.path.join(os.path.dirname(__file__), "configs", "accelerate_config.yaml")
+    ds_config = load_deepspeed_config_from_accelerate(accelerate_config_path, args, world_size)
+
+    if ds_config is None:
+        raise ValueError("DeepSpeed configuration not found in accelerate_config.yaml")
+
+    if world_rank == 0:
+        print("=" * 80)
+        print("DeepSpeed Configuration loaded from accelerate_config.yaml:")
+        print(json.dumps(ds_config, indent=2))
+        print("=" * 80)
+
+    # Initialize DeepSpeed with external optimizer and scheduler
+    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
+        model=model.lm,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        config=ds_config,
+        dist_init_required=False  # Ray Train handles distributed initialization
+    )
+
+    # Update model reference
+    model.lm = model_engine
+    model.set_device()
+
+    # Update training steps if needed
+    if override_train_step:
+        args.train_steps = len(train_dataloader) * args.train_epochs
+
+    if world_rank == 0:
+        print(f"******************************** Training step after prepare: {args.train_steps} ********************************")
+
+    # Get total number of samples
+    num_train_samples = sum(len(ds) for _, ds in train_datasets)
+
+    # Start training
+    ray_train(args, model, train_dataloader, valid_loaders,
+              optimizer, lr_scheduler, num_train_samples,
+              world_size, world_rank, device)
+
+
+def main():
+    args = parse_args()
+
+    # Initialize Ray
+    if not ray.is_initialized():
+        ray.init()
+
+    # Configure scaling
+    scaling_config = ScalingConfig(
+        num_workers=args.num_processes if hasattr(args, 'num_processes') and args.num_processes > 0 else 2,
+        use_gpu=True,
+        resources_per_worker={"CPU": 8, "GPU": 1},
+    )
+
+    # Create Ray TorchTrainer
+    trainer = TorchTrainer(
+        train_loop_per_worker=train_func,
+        train_loop_config={"args": args},
+        scaling_config=scaling_config,
+    )
+
+    # Start training
+    result = trainer.fit()
+
+    print("Training completed!")
+    print(f"Results: {result}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/F2LLM/train_with_ray.sh b/F2LLM/train_with_ray.sh
new file mode 100755
index 0000000..e990f75
--- /dev/null
+++ b/F2LLM/train_with_ray.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# F2LLM Ray distributed training launch script
+
+# Set environment variables
+export TOKENIZERS_PARALLELISM=false
+export NCCL_DEBUG=INFO
+export NCCL_IB_DISABLE=1  # if you are not using InfiniBand
+
+# Default config file
+CONFIG_FILE=${1:-"configs/config_ray.json"}
+
+echo "=========================================="
+echo "F2LLM Ray distributed training"
+echo "=========================================="
+echo "Config file: $CONFIG_FILE"
+echo "=========================================="
+
+# Check that the config file exists
+if [ ! -f "$CONFIG_FILE" ]; then
+    echo "Error: config file $CONFIG_FILE does not exist!"
+    exit 1
+fi
+
+# Check that Ray is installed
+python -c "import ray" 2>/dev/null
+if [ $? -ne 0 ]; then
+    echo "Error: Ray is not installed. Please run: pip install ray[train]"
+    exit 1
+fi
+
+# Launch training
+echo "Starting training..."
+python run_ray.py --config "$CONFIG_FILE"
+
+echo "=========================================="
+echo "Training finished!"
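+
+# Note (assumption, not part of the original workflow): for multi-node training you can
+# bring up a Ray cluster first with `ray start --head --port=6379` on the head node and
+# `ray start --address=<head_ip>:6379` on each worker, then run this script on the head
+# node only; `ray.init()` in run_ray.py should attach to the already-running cluster.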
+echo "==========================================" diff --git a/F2LLM/utils_ray.py b/F2LLM/utils_ray.py new file mode 100644 index 0000000..00e2530 --- /dev/null +++ b/F2LLM/utils_ray.py @@ -0,0 +1,329 @@ +from tqdm.auto import tqdm +from torch.utils.tensorboard import SummaryWriter +import torch +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss +import os +import torch.distributed as dist +from ray import train + +CLASSIFICATION_DATASETS = ['amazon_counterfactual', 'amazon_polarity', 'imdb', 'toxic_conversations', 'cola'] +CLUSTERING_DATASETS = ['amazon_reviews', 'banking77', 'emotion', 'mtop_intent', 'mtop_domain', 'massive_scenario', 'massive_intent', 'tweet_sentiment_extraction', 'arxiv_clustering_p2p', 'arxiv_clustering_s2s', 'biorxiv_clustering_p2p', 'biorxiv_clustering_s2s', 'medrxiv_clustering_p2p', 'medrxiv_clustering_s2s', 'reddit_clustering_p2p', 'reddit_clustering_s2s', 'stackexchange_clustering_p2p', 'stackexchange_clustering_s2s', 'twentynewsgroups'] +RETRIEVAL_DATASETS = ['arguana', 'snli', 'mnli', 'anli', 'paq', 'squad', 'stackexchange', 'msmarco', 'natural_questions', 'hotpotqa', 'fever', 'eli5', 'fiqa', 'bioasq', 'nfcorpus', 'miracl', 'mrtidy', 'scifact', 'qqp', 'stackoverflowdupquestions', 'sts12', 'sts22', 'stsbenchmark', 'amazon_qa', 'cnn_dm', 'coliee', 'paq_part2', 'pubmedqa', 's2orc_abstract_citation', 's2orc_title_abstract', 's2orc_title_citation', 'sentence_compression', 'specter', 'triviaqa', 'xsum', 'stackexchange_part2', 'stackexchangedupquestions_s2s', 'stackexchangedupquestions_p2p'] + + +def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps): + for key, value in log_dict.items(): + summary_writer.add_scalar(key, value, completed_steps) + + +def save_checkpoint(args, model, output_dir, lr_scheduler, world_rank): + """Save checkpoint using Ray Train""" + # Wait for all processes + if dist.is_initialized(): + dist.barrier() + + if world_rank == 0: + print(f"Saving checkpoint to {output_dir}") + os.makedirs(output_dir, exist_ok=True) + + if world_rank == 0: + model.tokenizer.save_pretrained(output_dir) + + # Check if using DeepSpeed + if hasattr(model.lm, 'save_checkpoint'): + # DeepSpeed model + model.lm.save_checkpoint(output_dir) + else: + # Unwrap DDP model + unwrapped_model = model.lm.module if hasattr(model.lm, 'module') else model.lm + + if world_rank == 0: + unwrapped_model.save_pretrained(output_dir) + + # Wait for all processes + if dist.is_initialized(): + dist.barrier() + + +def all_gather_tensor(tensor, world_size): + """Gather tensors from all processes""" + if not dist.is_initialized() or world_size == 1: + return tensor + + tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor) + return torch.cat(tensor_list, dim=0) + + +def inbatch_loss( + query_embeddings, # [bs, d] + context_embeddings, # [bs, d] + criterion, + world_size, + world_rank, + temperature=0.05, + ): + + bs = query_embeddings.size(0) + a_norm = F.normalize(query_embeddings, p=2, dim=-1) + + # Gather context embeddings from all processes + b_cross_gpus = all_gather_tensor(context_embeddings, world_size) # [bs*world_size, d] + b_norm_cross_gpus = F.normalize(b_cross_gpus, p=2, dim=-1) + + student_logits = torch.matmul(a_norm, b_norm_cross_gpus.t()) / temperature # [bs, bs*world_size] + + labels = torch.arange(bs, device=student_logits.device) + bs * world_rank + loss_bs = criterion(student_logits, labels) # (bs) + + loss = loss_bs.mean() + + return loss + + +def hard_loss( + 
query_embeddings, # [bs, d] + context_embeddings, # [bs, d] + hard_neg_embeddings, # [bs, num, d] + criterion, + temperature=0.05, + ): + + if hard_neg_embeddings is None: + return 0.0 + + bs = query_embeddings.size(0) + a_norm = F.normalize(query_embeddings, p=2, dim=-1) + + hard_neg_embeddings = torch.concat([ + context_embeddings.unsqueeze(1), + hard_neg_embeddings + ], dim=1) # [bs, num_hard+1, d] + + hard_norm = F.normalize(hard_neg_embeddings, p=2, dim=-1) + logits = (a_norm.unsqueeze(1) * hard_norm).sum(-1) / temperature # [bs, num_hard+1] + + loss_hard = criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean() + + return loss_hard + + +def all_reduce_mean(tensor, world_size): + """All reduce and compute mean across processes""" + if not dist.is_initialized() or world_size == 1: + return tensor + + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor / world_size + + +def validate(args, model, valid_loader_dict, criterion, completed_steps, summary_writer, world_size, world_rank): + eval_log_dict = {} + for dataset_name, valid_dataloader in valid_loader_dict.items(): + loss_ls, loss_hard_ls = [], [] + for batch in valid_dataloader: + with torch.no_grad(): + outputs = model.forward(batch) + loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + outputs['negative_passage_features'], + criterion) + + # Gather loss from all processes + loss_hard_gathered = all_gather_tensor(loss_hard.unsqueeze(0), world_size) + loss_hard_ls.append(loss_hard_gathered.float()) + + if dataset_name in RETRIEVAL_DATASETS: + loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + criterion, world_size, world_rank) + loss_gathered = all_gather_tensor(loss.unsqueeze(0), world_size) + loss_ls.append(loss_gathered.float()) + + # Wait for all processes + if dist.is_initialized(): + dist.barrier() + + loss_hard_ls = torch.cat(loss_hard_ls) + eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean() + if dataset_name in RETRIEVAL_DATASETS: + loss_ls = torch.cat(loss_ls) + eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = loss_ls.mean() + + eval_log_dict['Avg/retrieval/valid_loss_in_batch'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_in_batch')]).mean() + eval_log_dict['Avg/retrieval/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_hard')]).mean() + eval_log_dict['Avg/classification/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean() + eval_log_dict['Avg/clustering/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean() + + if world_rank == 0: + write_tensorboard(summary_writer, eval_log_dict, completed_steps) + print(f"[Validation] Step = {completed_steps}") + + # Report metrics to Ray Train + eval_report_dict = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in eval_log_dict.items()} + train.report(eval_report_dict) + + +def ray_train(args, + model, + train_dataloader, + valid_loader_dict, + optimizer, + lr_scheduler, + num_train_samples, + world_size, + world_rank, + device): + + if world_rank == 0: + print("**************************************** Start training ****************************************") + print(f" Num train 
samples = {num_train_samples}") + print(f" Num epochs = {args.train_epochs}") + print(f" Per device batch size = {args.train_batch_size}") + print(f" Global batch size = {args.train_batch_size * world_size}") + print(f" Step per epoch = {len(train_dataloader)}") + print(f" Total training steps = {args.train_steps}") + print("************************************************************************************************") + + global RETRIEVAL_DATASETS, CLASSIFICATION_DATASETS, CLUSTERING_DATASETS + RETRIEVAL_DATASETS = [ds for ds in RETRIEVAL_DATASETS if ds in train_dataloader.loader_dict.keys()] + CLASSIFICATION_DATASETS = [ds for ds in CLASSIFICATION_DATASETS if ds in train_dataloader.loader_dict.keys()] + CLUSTERING_DATASETS = [ds for ds in CLUSTERING_DATASETS if ds in train_dataloader.loader_dict.keys()] + + summary_writer = SummaryWriter(log_dir=args.tb_dir) if world_rank == 0 else None + criterion = CrossEntropyLoss(reduction='none') + pbar = tqdm(range(args.train_steps), disable=(world_rank != 0)) + completed_steps = 0 + + # Check if using DeepSpeed + is_deepspeed = hasattr(model.lm, 'backward') + + loss_dict = {ds_name: torch.tensor(0.0, device=device) for ds_name in RETRIEVAL_DATASETS} + loss_hard_dict = {ds_name: torch.tensor(0.0, device=device) for ds_name in train_dataloader.loader_dict.keys()} + count_dict = {ds_name: torch.tensor(0, device=device) for ds_name in RETRIEVAL_DATASETS} + count_hard_dict = {ds_name: torch.tensor(0, device=device) for ds_name in train_dataloader.loader_dict.keys()} + + model.lm.train() + for epoch in range(args.train_epochs): + if world_rank == 0: + print(f"*************** Starting epoch {epoch+1} ***************") + + train_dataloader.reset_epoch(epoch) + for batch in train_dataloader: + # Move batch to device + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + + # forward and compute loss + outputs = model.forward(batch) + + loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + outputs['negative_passage_features'], + criterion) + + dataset_name = batch['dataset_name'] + count_hard_dict[dataset_name] += 1 + loss_hard_dict[dataset_name] += loss_hard.detach().float() + + if dataset_name in RETRIEVAL_DATASETS: + loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + criterion, world_size, world_rank) + count_dict[dataset_name] += 1 + loss_dict[dataset_name] += loss.detach().float() + else: + loss = 0.0 + + loss_total = loss + loss_hard + + # backward, optimizer, scheduler - DeepSpeed compatible + if is_deepspeed: + # DeepSpeed handles backward, optimizer step, and scheduler step + model.lm.backward(loss_total) + model.lm.step() + else: + # Standard PyTorch training + optimizer.zero_grad() + loss_total.backward() + optimizer.step() + lr_scheduler.step() + + # Get current learning rate + if is_deepspeed: + current_lr = model.lm.get_lr()[0] + else: + current_lr = optimizer.param_groups[0]['lr'] + + # Apply minimum learning rate constraint + if current_lr < args.min_lr: + if is_deepspeed: + # For DeepSpeed, we need to update the lr in the config + for param_group in model.lm.optimizer.param_groups: + param_group['lr'] = args.min_lr + else: + for i in range(len(optimizer.param_groups)): + optimizer.param_groups[i]['lr'] = args.min_lr + + # log + completed_steps += 1 + if completed_steps % args.log_interval == 0: + pbar.update(args.log_interval) + + train_log_dict = {"lr": current_lr} + + 
for k in loss_dict.keys(): + count = all_reduce_mean(count_dict[k].clone(), world_size) + if count > 0: + loss_sum = all_reduce_mean(loss_dict[k].clone(), world_size) + train_log_dict[f"{k}/training_loss_in_batch"] = loss_sum / count + + for k in loss_hard_dict.keys(): + count = all_reduce_mean(count_hard_dict[k].clone(), world_size) + if count > 0: + loss_sum = all_reduce_mean(loss_hard_dict[k].clone(), world_size) + train_log_dict[f"{k}/training_loss_hard"] = loss_sum / count + + train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean() + train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean() + train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean() + train_log_dict['Avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean() + + if world_rank == 0: + print(f"[Train] Step = {completed_steps}") + write_tensorboard(summary_writer, train_log_dict, completed_steps) + + # Report metrics to Ray Train + report_dict = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in train_log_dict.items()} + train.report(report_dict) + + loss_dict = {ds_name: torch.tensor(0.0, device=device) for ds_name in RETRIEVAL_DATASETS} + loss_hard_dict = {ds_name: torch.tensor(0.0, device=device) for ds_name in train_dataloader.loader_dict.keys()} + count_dict = {ds_name: torch.tensor(0, device=device) for ds_name in RETRIEVAL_DATASETS} + count_hard_dict = {ds_name: torch.tensor(0, device=device) for ds_name in train_dataloader.loader_dict.keys()} + + # validation + if completed_steps % args.validation_steps == 0: + model.lm.eval() + validate(args, model, valid_loader_dict, criterion, completed_steps, summary_writer, world_size, world_rank) + model.lm.train() + + # step checkpoint + if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0: + output_dir = os.path.join(args.output_dir, f"step_{completed_steps}") + save_checkpoint(args, model, output_dir, lr_scheduler, world_rank) + + if completed_steps >= args.train_steps: + break + + # epoch checkpoint + output_dir = os.path.join(args.output_dir, f"epoch_{epoch+1}") + save_checkpoint(args, model, output_dir, lr_scheduler, world_rank) + + if completed_steps % args.validation_steps != 0: + model.lm.eval() + validate(args, model, valid_loader_dict, criterion, completed_steps, summary_writer, world_size, world_rank) + model.lm.train() + + if summary_writer: + summary_writer.close()