import logging
import traceback

import torch
from datasets import load_dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderModelCardData
from sentence_transformers.cross_encoder.evaluation import (
    CrossEncoderNanoBEIREvaluator,
    CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
from sentence_transformers.evaluation.SequentialEvaluator import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives

# Set the log level to INFO to get more informative output
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "prajjwal1/bert-tiny"

    train_batch_size = 2048
    num_epochs = 1
    num_hard_negatives = 5  # How many hard negatives to mine for each question-answer pair

    # 1. Define the CrossEncoder model to finetune, with optional model card metadata
    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name="BERT-tiny trained on GooAQ",
        ),
    )
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2a. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
    logging.info("Read the gooaq training dataset")
    full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 2b. Add hard negatives to the training dataset using a fast static embedding model
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,  # How many negatives to mine per question-answer pair
        margin=0,  # Negatives must be at least this much less similar to the query than the positive
        range_min=0,  # Skip none of the top-ranked candidates
        range_max=100,  # Consider only the 100 most similar candidates
        sampling_strategy="top",  # Sample the top negatives from the allowed range
        batch_size=4096,  # Batch size for the embedding model
        output_format="labeled-pair",  # (query, passage, label) pairs, as required by BinaryCrossEntropyLoss
        use_faiss=True,
    )
    logging.info(hard_train_dataset)

    # (Optional) Save the mined dataset to disk so the mining doesn't have to be repeated:
    # hard_train_dataset.save_to_disk("gooaq-hard-train")
    # Load it again later (requires `from datasets import load_from_disk`) with:
    # hard_train_dataset = load_from_disk("gooaq-hard-train")

    # 3. Define the training loss. pos_weight compensates for the class imbalance introduced by
    # mining: with num_hard_negatives (= 5) negatives per positive, a pos_weight of 5 is sensible.
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    # 4a. Define evaluators. CrossEncoderNanoBEIREvaluator is a lightweight reranking evaluator
    # over small subsets of English BEIR retrieval datasets.
    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )

    # 4b. Define a reranking evaluator by mining 30 candidate documents for each held-out
    # question. The positive answer is not disqualified from the candidates, so the evaluator
    # can also report the embedding model's retrieval performance as a baseline.
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],  # Use the full dataset's answers as the corpus
        num_negatives=30,  # How many documents the reranker should rerank
        batch_size=4096,
        disqualify_positives=False,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=[
            {
                "query": sample["question"],
                "positive": [sample["answer"]],
                "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
            }
            for sample in hard_eval_dataset
        ],
        batch_size=train_batch_size,
        name="gooaq-dev",
    )

    # 4c. Combine the evaluators & run the base model on them to get a pre-training baseline
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    evaluator(model)

    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-{short_model_name}-gooaq-bce"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=5e-4,
        warmup_ratio=0.1,
        fp16=False,  # Set to True if your GPU supports FP16 but not BF16
        bf16=True,  # Set to True only if your GPU supports BF16
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_R100_mean_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=20,
        save_strategy="steps",
        save_steps=20,
        save_total_limit=2,
        logging_steps=20,
        logging_first_step=True,
        run_name=run_name,  # Used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()
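
    # (Optional) A sketch for resuming interrupted training. CrossEncoderTrainer builds on the
    # Hugging Face `transformers` Trainer, so resuming from the latest checkpoint in `output_dir`
    # should work as follows (an assumption based on the inherited Trainer API):
    # trainer.train(resume_from_checkpoint=True)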

    # 7. Evaluate the final model, e.g. to include the results in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)
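
    # (Optional) A minimal inference sketch: reload the saved reranker and rank documents for a
    # query with `CrossEncoder.rank`. The query and documents are illustrative placeholders,
    # not taken from this training run.
    # reloaded_model = CrossEncoder(final_output_dir)
    # ranking = reloaded_model.rank(
    #     "how many calories in an egg",
    #     ["An egg has about 70 to 80 calories.", "Eggs are a good source of protein."],
    # )
    # print(ranking)  # A list of {"corpus_id": ..., "score": ...} dicts, best match first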

    # 9. (Optional) Push the trained model to the Hugging Face Hub.
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first.
    try:
        model.push_to_hub(f"cross-encoder-testing/{run_name}")
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()