diff --git a/F2LLM/configs/config.json b/F2LLM/configs/config.json
index 2ac3708..e9e696e 100644
--- a/F2LLM/configs/config.json
+++ b/F2LLM/configs/config.json
@@ -1,7 +1,7 @@
 {
-    "model_path": "models/qwen3-4b",
-    "experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs",
-    "train_data_path": "training_data/data_tokenized_qwen",
+    "model_path": "models/bert",
+    "experiment_id": "bert+lr.8e-6+bs.16x32+context.1024+1epochs",
+    "train_data_path": "data_tokenized_bert",
     "output_dir": "output",
     "tb_dir": "output/tb",
     "cache_dir": "cache",
diff --git a/F2LLM/model_bert.py b/F2LLM/model_bert.py
new file mode 100644
index 0000000..70ec869
--- /dev/null
+++ b/F2LLM/model_bert.py
@@ -0,0 +1,72 @@
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+class BertEmbedder:
+    def __init__(self,
+                 model_path,
+                 max_seq_length=512,
+                 args=None,
+                 pool_strategy="cls"
+                 ):
+        self.args = args
+        self.dtype = torch.bfloat16
+        self.device = None
+        self.encoder = AutoModel.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            torch_dtype=self.dtype
+        )
+        if hasattr(self.encoder.config, "use_cache"):
+            self.encoder.config.use_cache = False
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+        if self.tokenizer.pad_token is None:
+            if self.tokenizer.eos_token is not None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            elif self.tokenizer.sep_token is not None:
+                self.tokenizer.pad_token = self.tokenizer.sep_token
+            elif self.tokenizer.cls_token is not None:
+                self.tokenizer.pad_token = self.tokenizer.cls_token
+
+        self.max_seq_length = max_seq_length
+        self.pool_strategy = pool_strategy
+
+    def set_device(self):
+        self.device = self.encoder.device
+
+    def _pool(self, last_hidden_state, attention_mask):
+        # last_hidden_state: [bs, seq, d], attention_mask: [bs, seq]
+        if self.pool_strategy == "mean":
+            mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # [bs, seq, 1]
+            summed = (last_hidden_state * mask).sum(dim=1)                  # [bs, d]
+            denom = mask.sum(dim=1).clamp_min(1e-6)                         # [bs, 1]
+            return summed / denom
+        return last_hidden_state[:, 0, :]  # [CLS]
+
+    def forward(self, batch):
+        bs = batch['bs']
+        num_hard_neg = int((len(batch['input_ids']) - 2 * bs) / bs)
+
+        outputs = self.encoder(
+            input_ids=batch['input_ids'],
+            attention_mask=batch['attention_mask']
+        )
+        last_hidden_state = outputs.last_hidden_state
+        attn = batch['attention_mask']
+
+        query_emb = self._pool(last_hidden_state[:bs], attn[:bs]).unsqueeze(1)            # [bs, 1, d]
+        passage_emb = self._pool(last_hidden_state[bs:2*bs], attn[bs:2*bs]).unsqueeze(1)  # [bs, 1, d]
+
+        if num_hard_neg == 0:
+            neg_emb = None
+        else:
+            neg_all = self._pool(last_hidden_state[2*bs:], attn[2*bs:])  # [bs*num_hard_neg, d]
+            neg_emb = neg_all.view(bs, num_hard_neg, -1)                 # [bs, num_hard_neg, d]
+
+        return {
+            'query_passage_features': query_emb,
+            'passage_passage_features': passage_emb,
+            'negative_passage_features': neg_emb
+        }
\ No newline at end of file
diff --git a/F2LLM/tokenize_data_bert.py b/F2LLM/tokenize_data_bert.py
new file mode 100644
index 0000000..d9080bd
--- /dev/null
+++ b/F2LLM/tokenize_data_bert.py
@@ -0,0 +1,65 @@
+from multiprocessing import Pool
+import numpy as np
+import pandas as pd
+import os
+from transformers import AutoTokenizer
+from tqdm.auto import tqdm
+
+
+# Use the BERT checkpoint and output directory that configs/config.json points to.
+tokenizer = AutoTokenizer.from_pretrained('models/bert')
+# Standard BERT checkpoints support at most 512 positions; raise this if your checkpoint allows longer inputs.
+max_seq_length = 512
+
+output_dir = 'data_tokenized_bert'
+os.makedirs(output_dir, exist_ok=True)
+
+
+def process_sent(sentence):
+    # Keep special tokens so each sequence starts with [CLS] and ends with [SEP],
+    # which the CLS pooling in model_bert.py relies on.
+    tokenizer_outputs = tokenizer(sentence,
+                                  max_length=max_seq_length,
+                                  truncation=True,
+                                  add_special_tokens=True)
+    return np.array(tokenizer_outputs.input_ids)
+
+
+def process_sent_batch(s):
+    return s.apply(process_sent)
+
+
+def parallelize(data, func, num_of_processes=8):
+    indices = np.array_split(data.index, num_of_processes)
+    data_split = [data.iloc[idx] for idx in indices]
+    with Pool(num_of_processes) as pool:
+        data = pd.concat(pool.map(func, data_split))
+    return data
+
+
+root_dir = 'datasets'
+parquet_files = [f for f in os.listdir(root_dir) if f.endswith('.parquet')]
+
+for ds_name in tqdm(sorted(parquet_files)):
+    print(ds_name, flush=True)
+
+    df = pd.read_parquet(f"{root_dir}/{ds_name}")
+    df['query_input_ids'] = parallelize(df['query'], process_sent_batch, 62)
+
+    num_neg = 24 if 'negative_2' in df.keys() else 1
+
+    # Tokenize each unique passage/negative once, then map the results back onto the frame.
+    ls = df.passage.to_list()
+    for i in range(1, num_neg+1):
+        ls += df[f'negative_{i}'].to_list()
+    ls = list(set(ls))
+    df_tmp = pd.DataFrame({'text': ls})
+    df_tmp['input_ids'] = parallelize(df_tmp['text'], process_sent_batch, 62)
+    df_tmp = df_tmp.set_index('text')
+
+    df['passage_input_ids'] = df.passage.map(df_tmp.input_ids)
+
+    for i in range(1, num_neg+1):
+        df[f'negative_{i}_input_ids'] = df[f'negative_{i}'].map(df_tmp.input_ids)
+
+    df.to_parquet(f'{output_dir}/{ds_name}', index=False)
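Below is a minimal usage sketch of the new `BertEmbedder` (not part of the diff). It assumes a local checkpoint at `models/bert` as referenced in `configs/config.json`, that the script is run from the `F2LLM` directory so `model_bert` is importable, and it builds the batch layout that `forward` expects: `bs` queries, then `bs` positives, then `bs * num_hard_neg` hard negatives, all padded together. The example strings are placeholders.

```python
import torch
from model_bert import BertEmbedder

# Load the encoder and tokenizer; pool_strategy="cls" takes the [CLS] vector.
embedder = BertEmbedder("models/bert", pool_strategy="cls")
embedder.set_device()

queries   = ["what is dense retrieval?"]
positives = ["Dense retrieval encodes queries and passages into vectors."]
negatives = ["Unrelated text about cooking."]   # one hard negative per query

# forward() slices the batch as [queries | positives | hard negatives].
texts = queries + positives + negatives
enc = embedder.tokenizer(texts, padding=True, truncation=True,
                         max_length=embedder.max_seq_length, return_tensors="pt")

batch = {
    "input_ids": enc["input_ids"].to(embedder.device),
    "attention_mask": enc["attention_mask"].to(embedder.device),
    "bs": len(queries),
}

with torch.no_grad():
    out = embedder.forward(batch)

print(out["query_passage_features"].shape)     # [bs, 1, hidden]
print(out["negative_passage_features"].shape)  # [bs, num_hard_neg, hidden]
```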