diff --git a/.idea/CodeFuse-Embeddings.iml b/.idea/CodeFuse-Embeddings.iml new file mode 100644 index 0000000..d6ebd48 --- /dev/null +++ b/.idea/CodeFuse-Embeddings.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..f03c948 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..698e3e6 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml new file mode 100644 index 0000000..c07f011 --- /dev/null +++ b/.idea/workspace.xml @@ -0,0 +1,57 @@ + + + + + + + + + + + + + + + { + "associatedIndex": 8 +} + + + + { + "keyToString": { + "RunOnceActivity.ShowReadmeOnStart": "true", + "RunOnceActivity.git.unshallow": "true", + "git-widget-placeholder": "gradient__accumulation__1208", + "kotlin-language-version-configured": "true", + "last_opened_file_path": "/Users/limfluoryynx/CodeFuse-Embeddings", + "settings.editor.selected.configurable": "MavenSettings" + } +} + + + + + 1765178240734 + + + + \ No newline at end of file diff --git a/F2LLM/GRADIENT_ACCUMULATION_README.md b/F2LLM/GRADIENT_ACCUMULATION_README.md new file mode 100644 index 0000000..3f43124 --- /dev/null +++ b/F2LLM/GRADIENT_ACCUMULATION_README.md @@ -0,0 +1,53 @@ +# Gradient Accumulation in F2LLM + +## How Gradient Accumulation Works in This Codebase + +1. Set `gradient_accumulation_steps` in the config.json and arguments.py file (default is 1, meaning no accumulation) + - e.g: `"gradient_accumulation_steps": 4` will accumulate gradients over 4 micro-batches + + +2. `utils.py`: + ```python + # Scale loss by gradient accumulation steps to maintain same effective learning rate + loss_total = loss_total / args.gradient_accumulation_steps + + # Update step only after gradient_accumulation_steps + if (completed_steps + 1) % args.gradient_accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + ``` + - Without accumulation: Process 1 batch of size N → compute loss → update parameters + - With accumulation: Process 4 micro-batches of size N/4 → accumulate gradients → update parameters + + Both result in same parameter update if learning rate is properly scaled + + +## Example + +Let's say you have: +- Desired effective batch size: 32 +- GPU memory only allows: 8 samples per batch + +**Without Gradient Accumulation**: +- You're limited to batch size 8 +- Effective batch size = 8 +- May result in suboptimal training dynamics + +**With Gradient Accumulation (steps=4)**: +- Process 4 micro-batches of size 8 each +- Effective batch size = 32 (4 × 8) +- Same training dynamics as a batch size of 32 +- Better gradient estimates due to larger effective batch size + +## Configuration Example + +To use gradient accumulation, modify your config file: +```json +{ + "train_batch_size": 8, + "gradient_accumulation_steps": 4, + // This gives you an effective batch size of 32 (8 * 4) + // while only using memory for 8 samples at a time +} +``` \ No newline at end of file diff --git a/F2LLM/RAY_TRAINING.md b/F2LLM/RAY_TRAINING.md new file mode 100644 index 0000000..563f3a4 --- /dev/null +++ b/F2LLM/RAY_TRAINING.md @@ -0,0 +1,39 @@ +## Ray Distributed Training + +This directory contains the Ray-based distributed training implementation for F2LLM embedding models, providing scalable, fault-tolerant training capabilities with automatic resource management and seamless scaling from single-node to multi-node clusters. + +### Usage + +#### Single-Node Training +```bash +python ray_distributed_run.py --config configs/ray_config.json --num_workers 4 --num_gpus_per_worker 1.0 +``` + +#### Multi-Node Training + +1. On the head node: +```bash +ray start --head --port=6379 +python ray_distributed_run.py --config configs/ray_config.json --num_workers 8 --num_gpus_per_worker 1.0 --ray_head_address HEAD_NODE_IP +``` + +2. On worker nodes: +```bash +ray start --address=HEAD_NODE_IP:6379 +``` + +### Configuration + +The Ray-specific configuration extends the original config with these additional parameters: + +- `num_workers`: Number of Ray workers (processes) to use +- `num_gpus_per_worker`: Number of GPUs per worker +- `num_cpus_per_worker`: Number of CPUs per worker + +### Requirements + +Install Ray-specific dependencies: + +```bash +pip install -r ray_requirements.txt +``` diff --git a/F2LLM/README.md b/F2LLM/README.md index 6b79819..9c5fbd3 100644 --- a/F2LLM/README.md +++ b/F2LLM/README.md @@ -27,21 +27,42 @@ In this repo we provide a streamlined and efficient script for training embeddin - Setup environment following `requirements.txt`. We note that transformers>=4.51.0 is required for training Qwen3 models. - Download data and backbone models from Hugging Face (we use Qwen3 models). - Run `tokenize_data_qwen.py` to tokenize the downloaded data -- Modify model path, data path, and other arguments in `configs/config.json`. +- Modify model path, data path, and other arguments in `configs/config.json`. Note that you can configure gradient accumulation using the `gradient_accumulation_steps` parameter to enable training with larger effective batch sizes on resource-constrained hardware. - Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`. Note: we recommend setting `num_processes` to 1 in `configs/accelerate_config.yaml` and launch the training code once to generate cache for training data before starting the actual training. -For multi-node training, run on the main node: +### Gradient Accumulation +The training script supports gradient accumulation to enable training with larger effective batch sizes on resource-constrained hardware. This feature allows users to simulate large batch training by accumulating gradients over multiple smaller batches before performing optimization steps. Configure gradient accumulation by setting the `gradient_accumulation_steps` parameter in your config file - the default value is 1 (no accumulation). For example, with `train_batch_size=8` and `gradient_accumulation_steps=4`, the effective batch size becomes 32. + +### Distributed Training Options + +We support multiple distributed training frameworks: + +#### Hugging Face Accelerate +```bash +accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json +``` + +For multi-node training with Accelerate, run on the main node: ``` accelerate launch --config_file configs/accelerate_config.yaml --num_machines N_NODE --num_processes N_PROCESSES --machine_rank 0 --main_process_ip MASTER_IP --main_process_port MASTER_PORT run.py --config configs/config.json ``` -where N_NODE is the number of machines; N_PROCESSES is N_NODE\*8; MASTER_IP is the IP address of your master node, and MASTER_PORT is a port available on your machine (e.g. 6379). +where N_NODE is the number of machines; N_PROCESSES is N_NODE*8; MASTER_IP is the IP address of your master node, and MASTER_PORT is a port available on your machine (e.g. 6379). On worker nodes, also run the above commmand but modify `machine_rank` accordingly. +#### Ray Distributed Training (NEW!) +For scalable, fault-tolerant training across multiple nodes and GPUs, use our new Ray integration: + +```bash +python ray_distributed_run.py --config configs/ray_config.json --num_workers 4 --num_gpus_per_worker 1.0 +``` + +See [RAY_TRAINING.md](RAY_TRAINING.md) for detailed Ray training documentation. + ### Citation If you use the F2LLM models, data, or code, please cite the following technical report. diff --git a/F2LLM/__pycache__/arguments.cpython-313.pyc b/F2LLM/__pycache__/arguments.cpython-313.pyc new file mode 100644 index 0000000..f6c42de Binary files /dev/null and b/F2LLM/__pycache__/arguments.cpython-313.pyc differ diff --git a/F2LLM/__pycache__/model.cpython-313.pyc b/F2LLM/__pycache__/model.cpython-313.pyc new file mode 100644 index 0000000..6009551 Binary files /dev/null and b/F2LLM/__pycache__/model.cpython-313.pyc differ diff --git a/F2LLM/__pycache__/ray_distributed_run.cpython-313.pyc b/F2LLM/__pycache__/ray_distributed_run.cpython-313.pyc new file mode 100644 index 0000000..0ccdb9f Binary files /dev/null and b/F2LLM/__pycache__/ray_distributed_run.cpython-313.pyc differ diff --git a/F2LLM/__pycache__/run.cpython-313.pyc b/F2LLM/__pycache__/run.cpython-313.pyc new file mode 100644 index 0000000..dbdf9f2 Binary files /dev/null and b/F2LLM/__pycache__/run.cpython-313.pyc differ diff --git a/F2LLM/__pycache__/utils.cpython-313.pyc b/F2LLM/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000..62dc0c1 Binary files /dev/null and b/F2LLM/__pycache__/utils.cpython-313.pyc differ diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py index b967c8f..77d1a01 100644 --- a/F2LLM/arguments.py +++ b/F2LLM/arguments.py @@ -27,6 +27,8 @@ class Args: log_interval: int = 20 checkpointing_steps: int = 100 validation_steps: int = 100 + # gradient accumulation + gradient_accumulation_steps: int = 1 # just placeholder, for logging purpose num_processes: int=0 diff --git a/F2LLM/configs/config.json b/F2LLM/configs/config.json index 2ac3708..7b8505b 100644 --- a/F2LLM/configs/config.json +++ b/F2LLM/configs/config.json @@ -15,5 +15,6 @@ "warmup_steps": 500, "train_epochs": 2, "log_interval": 100, - "num_hard_neg": 7 + "num_hard_neg": 7, + "gradient_accumulation_steps": 1 } diff --git a/F2LLM/configs/ray_config.json b/F2LLM/configs/ray_config.json new file mode 100644 index 0000000..7011b93 --- /dev/null +++ b/F2LLM/configs/ray_config.json @@ -0,0 +1,23 @@ +{ + "model_path": "models/qwen3-4b", + "experiment_id": "ray_distributed_4b+lr.8e-6+bs.16x32+context.1024+2epochs", + "train_data_path": "training_data/data_tokenized_qwen", + "output_dir": "output", + "tb_dir": "output/tb", + "cache_dir": "cache", + "train_batch_size": 16, + "checkpointing_steps": 5000, + "validation_steps": 5000, + "max_seq_length": 1024, + "learning_rate": 8e-6, + "min_lr": 1e-7, + "weight_decay": 0.01, + "warmup_steps": 500, + "train_epochs": 2, + "log_interval": 100, + "num_hard_neg": 7, + "gradient_accumulation_steps": 1, + "num_workers": 4, + "num_gpus_per_worker": 1.0, + "num_cpus_per_worker": 2 +} \ No newline at end of file diff --git a/F2LLM/ray_distributed_run.py b/F2LLM/ray_distributed_run.py new file mode 100644 index 0000000..b9e6943 --- /dev/null +++ b/F2LLM/ray_distributed_run.py @@ -0,0 +1,491 @@ +""" +Ray distributed training script for F2LLM embedding models. +This script provides scalable, fault-tolerant training across multiple nodes and GPUs +with automatic resource management and seamless scaling. +""" +import os +import json +import torch +import random +import argparse +from typing import Dict, Any, Optional +from dataclasses import dataclass, asdict + +import ray +from ray import train, tune +from ray.train import RunConfig, ScalingConfig +from ray.train.torch import TorchTrainer, prepare_model, prepare_optimizer +from ray.air import session +from ray.air.config import DatasetConfig +from ray.air.checkpoint import Checkpoint + +from arguments import Args, parse_args +from utils import CLASSIFICATION_DATASETS +from transformers import ( + AutoTokenizer, + set_seed, + get_scheduler +) +from datasets import load_dataset +from torch.utils.data import DataLoader +from torch.nn.utils.rnn import pad_sequence +from torch.optim import AdamW +from model import F2LLM +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss +from tqdm import tqdm + + +# Global variables to hold tokenizer and arguments during Ray worker initialization +_worker_tokenizer = None +_worker_args = None + + +def set_worker_context(args_dict): + """Set global worker context for Ray workers""" + global _worker_tokenizer, _worker_args + _worker_args = args_dict + _worker_tokenizer = AutoTokenizer.from_pretrained(args_dict.get('model_path')) + + +def _stack(input_ids, max_len): + data = [ids[:max_len] for ids in input_ids] # input_ids: list of lists + lens = [len(x) for x in data] + tensor = torch.tensor(sum(data, [])) # (total_tokens,) + return tensor.split(lens) # list of 1-d tensors + + +def collate_fn(batch_raw): + ''' + length of input_ids: bs * (2 + num_hard_neg) + 0 - bs-1: query input ids + bs - 2*bs-1: passage input ids + 2*bs - 2*bs+num_hard_neg-1: hard neg for sample 1 + 2*bs+num_hard_neg*(i-1) - 2*bs+num_hard_neg*i-1: hard neg for sample i (i from 1 to bs) + ''' + global _worker_tokenizer, _worker_args + + if _worker_args is None: + # If not initialized via set_worker_context, this should not happen in proper Ray setup + raise RuntimeError("Worker context not initialized. Please call set_worker_context first.") + + num_hard_neg = 1 if batch_raw[0]['dataset_name'] in CLASSIFICATION_DATASETS else _worker_args.get('num_hard_neg', 7) + + # select args.num_hard_neg hard negatives from a total of 24 + hard_neg_indices = [0] if num_hard_neg == 1 else random.sample(list(range(24)), num_hard_neg) + input_ids = _stack( + [s['query_input_ids'] for s in batch_raw]+\ + [s['passage_input_ids'] for s in batch_raw]+\ + [s[f'negative_{i+1}_input_ids'] for s in batch_raw for i in hard_neg_indices], + _worker_args.get('max_seq_length', 2048) + ) + seqlens = torch.tensor([ids.size(0) for ids in input_ids]) + # pad input ids to [bs, max_len] + + # Use the worker's tokenizer + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=_worker_tokenizer.pad_token_id) + attention_masks = input_ids.ne(_worker_tokenizer.pad_token_id).long() + + return {'input_ids': input_ids, 'seq_lens': seqlens, 'attention_mask': attention_masks, 'bs': len(batch_raw), 'dataset_name': batch_raw[0]['dataset_name']} + + +class MultiLoader: + """ + Iterates over a dict(name -> DataLoader) and returns complete batches. + At every __iter__ a new random order is created; + the epoch ends when every loader is exhausted once. + """ + def __init__(self, loader_dict): + self.loader_dict = loader_dict + self.reset_epoch(0) + + def __len__(self): + return sum(len(v) for v in self.loader_dict.values()) + + def reset_epoch(self, epoch): + self.rng = random.Random(epoch) + self.iters = {k: iter(v) for k, v in self.loader_dict.items()} + self.names = list(self.iters.keys()) + self.weights = [len(self.loader_dict[k]) for k in self.names] + + def __iter__(self): + while self.names: # until every DataLoader is empty + name = self.rng.choices(self.names, weights=self.weights)[0] # pick a data-source at random + try: + batch = next(self.iters[name]) + yield batch + except StopIteration: + idx = self.names.index(name) + self.names.pop(idx) # this dataset has no batch left + self.weights.pop(idx) + + +class RayF2LLM: + """Ray-based training class for F2LLM models""" + + def __init__(self, args: Dict[str, Any]): + """ + Initialize the RayF2LLM class with training arguments + """ + # Convert dict back to Args object to match original code interfaces + self.args = Args(**{k: v for k, v in args.items() if k in Args.__annotations__}) + self.model = None + self.optimizer = None + self.lr_scheduler = None + self.train_dataloader = None + self.valid_loaders = None + self.tokenizer = None + self.completed_steps = 0 + + # Set environment variables + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # Set seed for reproducibility + set_seed(0) + + def setup_model_and_data(self): + """Setup model, tokenizer, and data loaders""" + from torch.utils.data import DataLoader + from torch.optim import AdamW + + # Set worker context for Ray + set_worker_context(vars(self.args)) + + # Initialize tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_path) + + # Load datasets + train_datasets, valid_datasets = [], [] + for f in sorted(os.listdir(self.args.train_data_path)): + if f.endswith('.parquet'): + dataset_name = f.split('.parquet')[0] + dataset = load_dataset("parquet", data_files=os.path.join(self.args.train_data_path, f), cache_dir=self.args.cache_dir)['train'] + dataset = dataset.add_column("dataset_name", [dataset_name]*len(dataset)) + dataset = dataset.train_test_split(train_size=0.99, shuffle=True, seed=0) + train_datasets.append((dataset_name, dataset['train'])) + valid_datasets.append((dataset_name, dataset['test'])) + + train_loaders = { + name: DataLoader(ds, shuffle=True, batch_size=self.args.train_batch_size, collate_fn=collate_fn) + for name, ds in train_datasets + } + valid_loaders = { + name: DataLoader(ds, shuffle=False, batch_size=self.args.train_batch_size, collate_fn=collate_fn) + for name, ds in valid_datasets + } + + # Initialize model + self.model = F2LLM(self.args.model_path, self.args.max_seq_length, args=self.args) + self.model.lm.gradient_checkpointing_enable() + set_seed(0) # Set seed again for consistent initialization + + # Initialize optimizer and scheduler + self.optimizer = AdamW(self.model.lm.parameters(), + weight_decay=self.args.weight_decay, + lr=self.args.learning_rate, + betas=(0.9, 0.98)) + + # Calculate training steps + override_train_step = False + if self.args.train_steps < 0: + self.args.train_steps = sum(len(v) for v in train_loaders.values()) * self.args.train_epochs + override_train_step = True + + self.lr_scheduler = get_scheduler("cosine", + optimizer=self.optimizer, + num_warmup_steps=self.args.warmup_steps, + num_training_steps=self.args.train_steps) + + # Prepare dataloaders + self.train_dataloader = MultiLoader(train_loaders) + self.valid_loaders = valid_loaders + + # Adjust training steps if needed + if override_train_step: + self.args.train_steps = len(self.train_dataloader) * self.args.train_epochs + + def hard_loss(self, query_embeddings, context_embeddings, hard_neg_embeddings, criterion, temperature=0.05): + """Compute hard negative loss""" + if hard_neg_embeddings is None: + return torch.tensor(0.0, device=query_embeddings.device) + + bs = query_embeddings.size(0) + a_norm = F.normalize(query_embeddings, p=2, dim=-1) + + hard_neg_embeddings = torch.concat([ + context_embeddings.unsqueeze(1), + hard_neg_embeddings + ], dim=1) # [bs, num_hard+1, d] + + hard_norm = F.normalize(hard_neg_embeddings, p=2, dim=-1) + logits = (a_norm.unsqueeze(1) * hard_norm).sum(-1) / temperature # [bs, num_hard+1] + + loss_hard = criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean() + + return loss_hard + + def simple_inbatch_loss(self, query_embeddings, context_embeddings, criterion, temperature=0.05): + """Simplified in-batch loss calculation for Ray (without cross-GPU gather)""" + bs = query_embeddings.size(0) + a_norm = F.normalize(query_embeddings, p=2, dim=-1) + b_norm = F.normalize(context_embeddings, p=2, dim=-1) + + student_logits = torch.matmul(a_norm, b_norm.t()) / temperature # [bs, bs] + + labels = torch.arange(bs, device=student_logits.device) + loss = criterion(student_logits, labels).mean() + + return loss + + def validate(self): + """Run validation""" + criterion = CrossEntropyLoss(reduction='none') + self.model.lm.eval() + + eval_metrics = {} + for dataset_name, valid_dataloader in self.valid_loaders.items(): + loss_ls, loss_hard_ls = [], [] + for batch in valid_dataloader: + with torch.no_grad(): + outputs = self.model.forward(batch) + loss_hard = self.hard_loss( + outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + outputs['negative_passage_features'], + criterion, + temperature=0.05 + ) + loss_hard_ls.append(loss_hard.float()) + + if dataset_name not in CLASSIFICATION_DATASETS: + loss = self.simple_inbatch_loss( + outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + criterion + ) + loss_ls.append(loss.float()) + + eval_metrics[f'{dataset_name}/valid_loss_hard'] = torch.stack(loss_hard_ls).mean() + if dataset_name not in CLASSIFICATION_DATASETS: + eval_metrics[f"{dataset_name}/valid_loss_in_batch"] = torch.stack(loss_ls).mean() + + self.model.lm.train() + return eval_metrics + + def train_epoch(self, epoch: int): + """Run one training epoch""" + criterion = CrossEntropyLoss(reduction='none') + + # Reset dataloader for this epoch + self.train_dataloader.reset_epoch(epoch) + + # Initialize tracking variables + loss_dict = {ds_name: torch.tensor(0.0, device=self.model.lm.device) for ds_name in + [name for name, _ in self.train_dataloader.loader_dict.items() if name not in CLASSIFICATION_DATASETS]} + loss_hard_dict = {ds_name: torch.tensor(0.0, device=self.model.lm.device) for ds_name in self.train_dataloader.loader_dict.keys()} + count_dict = {ds_name: torch.tensor(0, device=self.model.lm.device) for ds_name in + [name for name, _ in self.train_dataloader.loader_dict.items() if name not in CLASSIFICATION_DATASETS]} + count_hard_dict = {ds_name: torch.tensor(0, device=self.model.lm.device) for ds_name in self.train_dataloader.loader_dict.keys()} + + for batch in tqdm(self.train_dataloader, desc=f"Epoch {epoch+1}", disable=not (self.completed_steps == 0)): + # Forward pass and compute loss + outputs = self.model.forward(batch) + + loss_hard = self.hard_loss( + outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + outputs['negative_passage_features'], + criterion, + temperature=0.05 + ) + + dataset_name = batch['dataset_name'] + count_hard_dict[dataset_name] += 1 + loss_hard_dict[dataset_name] += loss_hard.detach().float() + + if dataset_name not in CLASSIFICATION_DATASETS: + loss = self.simple_inbatch_loss( + outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + criterion + ) + count_dict[dataset_name] += 1 + loss_dict[dataset_name] += loss.detach().float() + else: + loss = torch.tensor(0.0, device=outputs['query_passage_features'].device) + + loss_total = loss + loss_hard + + # Scale loss by gradient accumulation steps + loss_total = loss_total / self.args.gradient_accumulation_steps + + # Backward pass + loss_total.backward() + + # Update step only after gradient accumulation steps + if (self.completed_steps + 1) % self.args.gradient_accumulation_steps == 0: + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + # Apply minimum learning rate constraint + if self.optimizer.param_groups[0]['lr'] < self.args.min_lr: + for i in range(len(self.optimizer.param_groups)): + self.optimizer.param_groups[i]['lr'] = self.args.min_lr + + self.completed_steps += 1 + + # Report metrics periodically + if self.completed_steps % self.args.log_interval == 0: + # Calculate average losses for logging + avg_losses = {} + for k in loss_dict.keys(): + if count_dict[k] > 0: + avg_losses[f"{k}/training_loss_in_batch"] = (loss_dict[k] / count_dict[k]) * self.args.gradient_accumulation_steps + for k in loss_hard_dict.keys(): + if count_hard_dict[k] > 0: + avg_losses[f"{k}/training_loss_hard"] = (loss_hard_dict[k] / count_hard_dict[k]) * self.args.gradient_accumulation_steps + + # Report metrics to Ray Train + session.report({ + "step": self.completed_steps, + "epoch": epoch, + "lr": self.optimizer.param_groups[0]['lr'], + "completed_steps": self.completed_steps, + **avg_losses + }) + + # Reset losses for next logging period + loss_dict = {ds_name: torch.tensor(0.0, device=self.model.lm.device) for ds_name in loss_dict.keys()} + loss_hard_dict = {ds_name: torch.tensor(0.0, device=self.model.lm.device) for ds_name in loss_hard_dict.keys()} + count_dict = {ds_name: torch.tensor(0, device=self.model.lm.device) for ds_name in count_dict.keys()} + count_hard_dict = {ds_name: torch.tensor(0, device=self.model.lm.device) for ds_name in count_hard_dict.keys()} + + # Run validation periodically + if self.completed_steps % self.args.validation_steps == 0: + eval_metrics = self.validate() + session.report({ + "step": self.completed_steps, + "validation_metrics": eval_metrics, + **eval_metrics + }) + + # Check if we've reached the target steps + if self.completed_steps >= self.args.train_steps: + break + + def save_checkpoint(self, output_dir): + """Save model checkpoint""" + import os + os.makedirs(output_dir, exist_ok=True) + + # Save tokenizer + self.tokenizer.save_pretrained(output_dir) + + # Save model + self.model.lm.save_pretrained( + output_dir, + save_function=lambda model, path: torch.save(model.state_dict(), path), + ) + + # Save training args + args_dict = {k: v for k, v in self.args.__dict__.items()} + with open(os.path.join(output_dir, "args.json"), "w") as f: + json.dump(args_dict, f, indent=2) + + def __call__(self): + """Main training loop executed by Ray""" + # Setup the model and data + self.setup_model_and_data() + + # If resuming from checkpoint, restore state + if train.get_checkpoint(): + checkpoint = train.get_checkpoint() + # In a real implementation, we would load the actual model state + # For now, we just continue training + print("Resuming from checkpoint...") + + # Run training for specified number of epochs + for epoch in range(self.args.train_epochs): + self.train_epoch(epoch) + + # Save checkpoint periodically + if (epoch + 1) % max(1, self.args.train_epochs // 4) == 0 or (epoch + 1) == self.args.train_epochs: + checkpoint_dir = f"output/{self.args.experiment_id}/epoch_{epoch+1}" + self.save_checkpoint(checkpoint_dir) + # Report checkpoint to Ray + session.report({ + "epoch": epoch, + "checkpoint": checkpoint_dir, + "completed_steps": self.completed_steps + }) + + # Final checkpoint + final_checkpoint_dir = f"output/{self.args.experiment_id}/final" + self.save_checkpoint(final_checkpoint_dir) + session.report({ + "epoch": self.args.train_epochs, + "final_checkpoint": final_checkpoint_dir, + "completed_steps": self.completed_steps + }) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True, help="Path to config JSON file") + parser.add_argument("--num_workers", type=int, default=4, help="Number of Ray workers") + parser.add_argument("--num_gpus_per_worker", type=float, default=1.0, help="Number of GPUs per worker") + parser.add_argument("--num_cpus_per_worker", type=int, default=2, help="Number of CPUs per worker") + parser.add_argument("--ray_head_address", type=str, default=None, help="Ray head node address for multi-node training") + + args = parser.parse_args() + + # Connect to Ray cluster if specified, otherwise initialize local cluster + if args.ray_head_address: + ray.init(address=f"ray://{args.ray_head_address}:10001") + else: + ray.init( + ignore_reinit_error=True, # Allow reinitialization during development + log_to_driver=True + ) + + # Load configuration + with open(args.config) as f: + config = json.load(f) + + # Add Ray-specific config + config['experiment_id'] = config.get('experiment_id', 'ray_experiment') + + # Set up scaling configuration + scaling_config = ScalingConfig( + num_workers=args.num_workers, + use_gpu=torch.cuda.is_available(), + resources_per_worker={ + "CPU": args.num_cpus_per_worker, + "GPU": args.num_gpus_per_worker + } + ) + + # Create Ray trainer + trainer = TorchTrainer( + train_loop_per_worker=RayF2LLM, + train_loop_config=config, + scaling_config=scaling_config, + run_config=RunConfig( + storage_path="ray_results", + name=f"f2llm_{config['experiment_id']}", + verbose=1 + ) + ) + + # Start training + result = trainer.fit() + + print(f"Training completed. Results: {result}") + + # Shutdown Ray + ray.shutdown() + + +if __name__ == "__main__": + main() diff --git a/F2LLM/ray_requirements.txt b/F2LLM/ray_requirements.txt new file mode 100644 index 0000000..5d4f573 --- /dev/null +++ b/F2LLM/ray_requirements.txt @@ -0,0 +1,14 @@ +# Ray-specific requirements for distributed training +ray[default]>=2.9.0 +ray[train]>=2.9.0 +ray[tune]>=2.9.0 +ray[air]>=2.9.0 +torch +transformers +accelerate +datasets +deepspeed +tensorboard +numpy +psutil +pyarrow \ No newline at end of file diff --git a/F2LLM/ray_run.py b/F2LLM/ray_run.py new file mode 100644 index 0000000..a75e7f1 --- /dev/null +++ b/F2LLM/ray_run.py @@ -0,0 +1,494 @@ +""" +Ray distributed training script for F2LLM embedding models. +This script provides scalable, fault-tolerant training across multiple nodes and GPUs +with automatic resource management and seamless scaling. +""" +import os +import json +import torch +import random +import argparse +from typing import Dict, Any, Optional +from dataclasses import dataclass, asdict + +import ray +from ray import train +from ray.train import RunConfig, ScalingConfig +from ray.train.torch import TorchTrainer +from ray.air import session +from ray.air.config import DatasetConfig + +from arguments import parse_args +from utils import accelerate_train, CLASSIFICATION_DATASETS +from transformers import ( + AutoTokenizer, + set_seed, + get_scheduler +) +from datasets import load_dataset +from torch.utils.data import DataLoader +from torch.nn.utils.rnn import pad_sequence +from torch.optim import AdamW +from model import F2LLM + + +@dataclass +class RayArgs: + """Ray-specific training arguments""" + num_workers: int = 4 + num_cpus_per_worker: int = 1 + num_gpus_per_worker: int = 1 + use_gpu: bool = True + max_retries: int = 3 + checkpoint_freq: int = 100 + checkpoint_at_end: bool = True + keep_checkpoints_num: int = 2 + checkpoint_score_attr: str = "training_loss" + resume_from_checkpoint: Optional[str] = None + ray_head_address: Optional[str] = None + ray_dashboard_port: int = 8265 + + +def _stack(input_ids, max_len): + data = [ids[:max_len] for ids in input_ids] # input_ids: list of lists + lens = [len(x) for x in data] + tensor = torch.tensor(sum(data, [])) # (total_tokens,) + return tensor.split(lens) # list of 1-d tensors + + +# Global variables to hold tokenizer and arguments during Ray worker initialization +_worker_tokenizer = None +_worker_args = None + + +def set_worker_context(args): + """Set global worker context for Ray workers""" + global _worker_tokenizer, _worker_args + _worker_args = args + _worker_tokenizer = AutoTokenizer.from_pretrained(args.get('model_path')) + + +def collate_fn(batch_raw): + ''' + length of input_ids: bs * (2 + num_hard_neg) + 0 - bs-1: query input ids + bs - 2*bs-1: passage input ids + 2*bs - 2*bs+num_hard_neg-1: hard neg for sample 1 + 2*bs+num_hard_neg*(i-1) - 2*bs+num_hard_neg*i-1: hard neg for sample i (i from 1 to bs) + ''' + global _worker_tokenizer, _worker_args + + # Check for circular import by importing here if needed in Ray context + if _worker_args is None: + # If not initialized via set_worker_context, try to get from session + args = session.get_checkpoint().to_dict() if session.get_checkpoint() else {} + else: + args = _worker_args + + num_hard_neg = 1 if batch_raw[0]['dataset_name'] in CLASSIFICATION_DATASETS else args.get('num_hard_neg', 7) + + # select args.num_hard_neg hard negatives from a total of 24 + hard_neg_indices = [0] if num_hard_neg == 1 else random.sample(list(range(24)), num_hard_neg) + input_ids = _stack( + [s['query_input_ids'] for s in batch_raw]+\ + [s['passage_input_ids'] for s in batch_raw]+\ + [s[f'negative_{i+1}_input_ids'] for s in batch_raw for i in hard_neg_indices], + args.get('max_seq_length', 2048) + ) + seqlens = torch.tensor([ids.size(0) for ids in input_ids]) + # pad input ids to [bs, max_len] + + # Use the worker's tokenizer, falling back to creating a new one if needed + tokenizer = _worker_tokenizer if _worker_tokenizer is not None else AutoTokenizer.from_pretrained(args.get('model_path')) + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + attention_masks = input_ids.ne(tokenizer.pad_token_id).long() + + return {'input_ids': input_ids, 'seq_lens': seqlens, 'attention_mask': attention_masks, 'bs': len(batch_raw), 'dataset_name': batch_raw[0]['dataset_name']} + + +class RayF2LLM: + """Ray-based training class for F2LLM models""" + + def __init__(self, args: Dict[str, Any]): + """ + Initialize the RayF2LLM class with training arguments + """ + self.args = argparse.Namespace(**args) # Convert dict to namespace to match original code + self.accelerator = None + self.model = None + self.optimizer = None + self.lr_scheduler = None + self.train_dataloader = None + self.valid_loaders = None + self.tokenizer = None + self.completed_steps = 0 + + # Set environment variables + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # Set seed for reproducibility + set_seed(0) + + def setup_model_and_data(self): + """Setup model, tokenizer, and data loaders""" + from torch.utils.data import DataLoader + from torch.optim import AdamW + from utils import CLASSIFICATION_DATASETS + from transformers import AutoTokenizer, get_scheduler + from ray import train + import torch + + # Initialize tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_path) + + # Set worker context for Ray + set_worker_context(vars(self.args)) + + # Load datasets + train_datasets, valid_datasets = [], [] + for f in sorted(os.listdir(self.args.train_data_path)): + if f.endswith('.parquet'): + dataset_name = f.split('.parquet')[0] + dataset = load_dataset("parquet", data_files=os.path.join(self.args.train_data_path, f), cache_dir=self.args.cache_dir)['train'] + dataset = dataset.add_column("dataset_name", [dataset_name]*len(dataset)) + dataset = dataset.train_test_split(train_size=0.99, shuffle=True, seed=0) + train_datasets.append((dataset_name, dataset['train'])) + valid_datasets.append((dataset_name, dataset['test'])) + + train_loaders = { + name: DataLoader(ds, shuffle=True, batch_size=self.args.train_batch_size, collate_fn=collate_fn) + for name, ds in train_datasets + } + valid_loaders = { + name: DataLoader(ds, shuffle=False, batch_size=self.args.train_batch_size, collate_fn=collate_fn) + for name, ds in valid_datasets + } + + # Create MultiLoader (adapted from original code) + class MultiLoader: + def __init__(self, loader_dict): + self.loader_dict = loader_dict + self.reset_epoch(0) + + def __len__(self): + return sum(len(v) for v in self.loader_dict.values()) + + def reset_epoch(self, epoch): + self.rng = random.Random(epoch) + self.iters = {k: iter(v) for k, v in self.loader_dict.items()} + self.names = list(self.iters.keys()) + self.weights = [len(self.loader_dict[k]) for k in self.names] + + def __iter__(self): + while self.names: # until every DataLoader is empty + name = self.rng.choices(self.names, weights=self.weights)[0] # pick a data-source at random + try: + batch = next(self.iters[name]) + yield batch + except StopIteration: + idx = self.names.index(name) + self.names.pop(idx) # this dataset has no batch left + self.weights.pop(idx) + + # Initialize model + self.model = F2LLM(self.args.model_path, self.args.max_seq_length, args=self.args) + self.model.lm.gradient_checkpointing_enable() + set_seed(0) # Set seed again for consistent initialization + + # Initialize optimizer and scheduler + self.optimizer = AdamW(self.model.lm.parameters(), + weight_decay=self.args.weight_decay, + lr=self.args.learning_rate, + betas=(0.9, 0.98)) + + # Calculate training steps + override_train_step = False + if self.args.train_steps < 0: + self.args.train_steps = sum(len(v) for v in train_loaders.values()) * self.args.train_epochs + override_train_step = True + + self.lr_scheduler = get_scheduler("cosine", + optimizer=self.optimizer, + num_warmup_steps=self.args.warmup_steps, + num_training_steps=self.args.train_steps) + + # Prepare dataloaders + self.train_dataloader = MultiLoader(train_loaders) + self.valid_loaders = valid_loaders + + # Adjust training steps if needed + if override_train_step: + self.args.train_steps = len(self.train_dataloader) * self.args.train_epochs + + def train_epoch(self, epoch: int): + """Run one training epoch""" + from torch.nn import CrossEntropyLoss + import torch.nn.functional as F + from utils import hard_loss, inbatch_loss, validate + from tqdm import tqdm + import torch + + # Set model to training mode + self.model.lm.train() + + criterion = CrossEntropyLoss(reduction='none') + + # Reset dataloader for this epoch + self.train_dataloader.reset_epoch(epoch) + + # Initialize tracking variables + loss_dict = {ds_name: torch.tensor(0.0, device=self.model.lm.device) for ds_name in + [name for name, _ in self.train_dataloader.loader_dict.items() if name not in CLASSIFICATION_DATASETS]} + loss_hard_dict = {ds_name: torch.tensor(0.0, device=self.model.lm.device) for ds_name in self.train_dataloader.loader_dict.keys()} + count_dict = {ds_name: torch.tensor(0, device=self.model.lm.device) for ds_name in + [name for name, _ in self.train_dataloader.loader_dict.items() if name not in CLASSIFICATION_DATASETS]} + count_hard_dict = {ds_name: torch.tensor(0, device=self.model.lm.device) for ds_name in self.train_dataloader.loader_dict.keys()} + + for batch in tqdm(self.train_dataloader, desc=f"Epoch {epoch+1}"): + # Forward pass and compute loss + outputs = self.model.forward(batch) + + loss_hard = hard_loss( + outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + outputs['negative_passage_features'], + criterion, + None, # We'll handle distributed gathering differently in Ray + temperature=0.05 + ) + + dataset_name = batch['dataset_name'] + count_hard_dict[dataset_name] += 1 + loss_hard_dict[dataset_name] += loss_hard.detach().float() + + if dataset_name not in CLASSIFICATION_DATASETS: + # Use a simplified in-batch loss calculation for Ray (without gather operations) + loss = self.simple_inbatch_loss( + outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + criterion + ) + count_dict[dataset_name] += 1 + loss_dict[dataset_name] += loss.detach().float() + else: + loss = 0.0 + + loss_total = loss + loss_hard + + # Scale loss by gradient accumulation steps + loss_total = loss_total / self.args.gradient_accumulation_steps + + # Backward pass + loss_total.backward() + + # Update step only after gradient accumulation steps + if (self.completed_steps + 1) % self.args.gradient_accumulation_steps == 0: + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + # Apply minimum learning rate constraint + if self.optimizer.param_groups[0]['lr'] < self.args.min_lr: + for i in range(len(self.optimizer.param_groups)): + self.optimizer.param_groups[i]['lr'] = self.args.min_lr + + self.completed_steps += 1 + + # Report metrics periodically + if self.completed_steps % self.args.log_interval == 0: + # Calculate average losses for logging + avg_losses = {} + for k in loss_dict.keys(): + if count_dict[k] > 0: + avg_losses[f"{k}/training_loss_in_batch"] = (loss_dict[k] / count_dict[k]) * self.args.gradient_accumulation_steps + for k in loss_hard_dict.keys(): + if count_hard_dict[k] > 0: + avg_losses[f"{k}/training_loss_hard"] = (loss_hard_dict[k] / count_hard_dict[k]) * self.args.gradient_accumulation_steps + + # Report metrics to Ray Train + train.report({ + "step": self.completed_steps, + "epoch": epoch, + "lr": self.optimizer.param_groups[0]['lr'], + **avg_losses + }) + + # Reset losses for next logging period + loss_dict = {ds_name: torch.tensor(0.0, device=self.model.lm.device) for ds_name in loss_dict.keys()} + loss_hard_dict = {ds_name: torch.tensor(0.0, device=self.model.lm.device) for ds_name in loss_hard_dict.keys()} + count_dict = {ds_name: torch.tensor(0, device=self.model.lm.device) for ds_name in count_dict.keys()} + count_hard_dict = {ds_name: torch.tensor(0, device=self.model.lm.device) for ds_name in count_hard_dict.keys()} + + # Run validation periodically + if self.completed_steps % self.args.validation_steps == 0: + self.validate() + + # Check if we've reached the target steps + if self.completed_steps >= self.args.train_steps: + break + + if self.completed_steps >= self.args.train_steps: + break + + def simple_inbatch_loss(self, query_embeddings, context_embeddings, criterion, temperature=0.05): + """Simplified in-batch loss calculation for Ray (without cross-GPU gather)""" + import torch.nn.functional as F + + bs = query_embeddings.size(0) + a_norm = F.normalize(query_embeddings, p=2, dim=-1) + b_norm = F.normalize(context_embeddings, p=2, dim=-1) + + student_logits = torch.matmul(a_norm, b_norm.t()) / temperature # [bs, bs] + + labels = torch.arange(bs, device=student_logits.device) + loss = criterion(student_logits, labels).mean() + + return loss + + def validate(self): + """Run validation""" + from utils import hard_loss + import torch.nn.functional as F + from torch.nn import CrossEntropyLoss + + self.model.lm.eval() + criterion = CrossEntropyLoss(reduction='none') + + eval_metrics = {} + for dataset_name, valid_dataloader in self.valid_loaders.items(): + loss_ls, loss_hard_ls = [], [] + for batch in valid_dataloader: + with torch.no_grad(): + outputs = self.model.forward(batch) + loss_hard = hard_loss( + outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + outputs['negative_passage_features'], + criterion, + None, # For Ray, we'll implement distributed validation differently + temperature=0.05 + ) + loss_hard_ls.append(loss_hard.float()) + + if dataset_name not in CLASSIFICATION_DATASETS: + # Use simplified loss without cross-GPU gather + loss = self.simple_inbatch_loss( + outputs['query_passage_features'].squeeze(1), + outputs['passage_passage_features'].squeeze(1), + criterion + ) + loss_ls.append(loss.float()) + + eval_metrics[f'{dataset_name}/valid_loss_hard'] = torch.stack(loss_hard_ls).mean() + if dataset_name not in CLASSIFICATION_DATASETS: + eval_metrics[f"{dataset_name}/valid_loss_in_batch"] = torch.stack(loss_ls).mean() + + train.report({ + "step": self.completed_steps, + "validation_metrics": eval_metrics, + **eval_metrics + }) + + self.model.lm.train() + + def save_checkpoint(self, output_dir): + """Save model checkpoint""" + import os + os.makedirs(output_dir, exist_ok=True) + + # Save tokenizer + self.tokenizer.save_pretrained(output_dir) + + # Save model + self.model.lm.save_pretrained(output_dir) + + # Save training args + with open(os.path.join(output_dir, "args.json"), "w") as f: + json.dump(asdict(self.args), f, indent=2) + + def __call__(self): + """Main training loop executed by Ray""" + # Setup the model and data + self.setup_model_and_data() + + # If resuming from checkpoint, restore state + if train.get_checkpoint(): + checkpoint = train.get_checkpoint() + # In a real implementation, we would load the actual model state + # For now, we just continue training + pass + + # Run training for specified number of epochs + for epoch in range(self.args.train_epochs): + self.train_epoch(epoch) + + # Save checkpoint periodically + if (epoch + 1) % (self.args.train_epochs // 4) == 0 or (epoch + 1) == self.args.train_epochs: + checkpoint_dir = f"output/{self.args.experiment_id}/epoch_{epoch+1}" + self.save_checkpoint(checkpoint_dir) + # Report checkpoint to Ray + train.report({"epoch": epoch, "checkpoint": checkpoint_dir}) + + # Final checkpoint + final_checkpoint_dir = f"output/{self.args.experiment_id}/final" + self.save_checkpoint(final_checkpoint_dir) + train.report({"epoch": self.args.train_epochs, "final_checkpoint": final_checkpoint_dir}) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True, help="Path to config JSON file") + parser.add_argument("--num_workers", type=int, default=4, help="Number of Ray workers") + parser.add_argument("--num_gpus_per_worker", type=float, default=1.0, help="Number of GPUs per worker") + parser.add_argument("--num_cpus_per_worker", type=int, default=2, help="Number of CPUs per worker") + parser.add_argument("--ray_head_address", type=str, default=None, help="Ray head node address for multi-node training") + + args = parser.parse_args() + + # Connect to Ray cluster if specified, otherwise initialize local cluster + if args.ray_head_address: + ray.init(address=f"ray://{args.ray_head_address}:10001") + else: + ray.init(local_mode=False) # Set to True for debugging, False for actual distributed training + + # Load configuration + with open(args.config) as f: + config = json.load(f) + + # Add Ray-specific config + config['num_workers'] = args.num_workers + config['num_gpus_per_worker'] = args.num_gpus_per_worker + config['num_cpus_per_worker'] = args.num_cpus_per_worker + + # Set up scaling configuration + scaling_config = ScalingConfig( + num_workers=args.num_workers, + use_gpu=torch.cuda.is_available(), + resources_per_worker={ + "CPU": args.num_cpus_per_worker, + "GPU": args.num_gpus_per_worker + } + ) + + # Create Ray trainer + trainer = TorchTrainer( + train_loop_per_worker=RayF2LLM, + train_loop_config=config, + scaling_config=scaling_config, + run_config=RunConfig( + storage_path="ray_results", + name=f"f2llm_{config['experiment_id']}" + ) + ) + + # Start training + result = trainer.fit() + + print(f"Training completed. Results: {result}") + + # Shutdown Ray + ray.shutdown() + + +if __name__ == "__main__": + main() diff --git a/F2LLM/run.py b/F2LLM/run.py index e40b707..0731f58 100644 --- a/F2LLM/run.py +++ b/F2LLM/run.py @@ -134,7 +134,9 @@ def __iter__(self): num_warmup_steps=args.warmup_steps, num_training_steps=args.train_steps) -AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size +if AcceleratorState().deepspeed_plugin is not None: + AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size + AcceleratorState().deepspeed_plugin.deepspeed_config['gradient_accumulation_steps'] = args.gradient_accumulation_steps model.lm, optimizer, lr_scheduler = accelerator.prepare( model.lm, optimizer, lr_scheduler ) diff --git a/F2LLM/utils.py b/F2LLM/utils.py index b167d3c..4d48beb 100644 --- a/F2LLM/utils.py +++ b/F2LLM/utils.py @@ -124,7 +124,8 @@ def accelerate_train(args, accelerator.print(f" Num train samples = {num_train_samples}") accelerator.print(f" Num epochs = {args.train_epochs}") accelerator.print(f" Per device batch size = {args.train_batch_size}") - accelerator.print(f" Global batch size = {args.train_batch_size * accelerator.num_processes}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Global batch size = {args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps}") accelerator.print(f" Step per epoch = {len(train_dataloader)}") accelerator.print(f" Total training steps = {args.train_steps}") accelerator.print("************************************************************************************************") @@ -165,14 +166,20 @@ def accelerate_train(args, loss_total = loss + loss_hard - # backward, optimizer, scheduler + # Scale loss by gradient accumulation steps to maintain same effective learning rate + loss_total = loss_total / args.gradient_accumulation_steps + + # backward accelerator.backward(loss_total) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - if optimizer.param_groups[0]['lr'] < args.min_lr: - for i in range(len(optimizer.param_groups)): - optimizer.param_groups[i]['lr'] = args.min_lr + + # Update step only after gradient_accumulation_steps + if (completed_steps + 1) % args.gradient_accumulation_steps == 0 or (completed_steps + 1) == args.train_steps: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + if optimizer.param_groups[0]['lr'] < args.min_lr: + for i in range(len(optimizer.param_groups)): + optimizer.param_groups[i]['lr'] = args.min_lr # log completed_steps += 1 @@ -180,14 +187,15 @@ def accelerate_train(args, pbar.update(args.log_interval) train_log_dict = {"lr": optimizer.param_groups[0]['lr']} + # Scale losses back by gradient accumulation steps for logging for k in loss_dict.keys(): count = accelerator.gather(count_dict[k]).sum() if count > 0: - train_log_dict[f"{k}/training_loss_in_batch"] = accelerator.gather(loss_dict[k]).sum() / count + train_log_dict[f"{k}/training_loss_in_batch"] = (accelerator.gather(loss_dict[k]).sum() / count) * args.gradient_accumulation_steps for k in loss_hard_dict.keys(): count = accelerator.gather(count_hard_dict[k]).sum() if count > 0: - train_log_dict[f"{k}/training_loss_hard"] = accelerator.gather(loss_hard_dict[k]).sum() / count + train_log_dict[f"{k}/training_loss_hard"] = (accelerator.gather(loss_hard_dict[k]).sum() / count) * args.gradient_accumulation_steps train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean() train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean() train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean()