File size: 1,333 Bytes
4be79fa 3c397fd 5d5c713 3c397fd 4be79fa 5d5c713 3c397fd 5d5c713 3c397fd 5d5c713 3c397fd 5d5c713 3c397fd 0db2cde 3c397fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import functools
import os

import torch
from comet import download_model, load_from_checkpoint
# Point COMET's cache directory at /tmp via the COMET_CACHE environment variable.
# NOTE(review): this runs after `comet` has already been imported, and
# `download_model` below is called without an explicit cache/saving directory
# argument — confirm the installed COMET version actually reads COMET_CACHE.
os.environ.update(COMET_CACHE="/tmp")
@functools.lru_cache(maxsize=4)
def _load_comet_model(model_name):
    """Download (if needed) and load a COMET checkpoint, memoized per model name.

    Loading a checkpoint is expensive, so repeated scoring calls reuse the same
    model object. A failed load is not cached (``lru_cache`` does not store
    exceptions), so a later call will retry.
    """
    model_path = download_model(model_name)
    return load_from_checkpoint(model_path)


def calculate_comet(
    source_sentences,
    translations,
    references,
    *,
    model_name="Unbabel/wmt22-comet-da",
    batch_size=8,
):
    """
    Calculate COMET scores for a list of translations.

    The three inputs are paired element-wise. This function never raises:
    any failure (mismatched input lengths, model download/load error, or an
    inference error) is printed, and one ``0.0`` score per source sentence
    is returned instead.

    :param source_sentences: List of source sentences.
    :param translations: List of translated sentences (hypotheses).
    :param references: List of reference translations.
    :param model_name: COMET model identifier passed to ``download_model``.
    :param batch_size: Batch size passed to ``model.predict``.
    :return: List of COMET scores (one score per sentence pair); an empty
        list for empty input.
    """
    try:
        sources = list(source_sentences)
        hypotheses = list(translations)
        refs = list(references)

        # zip() would silently drop unmatched trailing items and yield fewer
        # scores than the caller has sentences, so treat a mismatch as an error.
        if not len(sources) == len(hypotheses) == len(refs):
            raise ValueError(
                "source_sentences, translations and references must have the "
                f"same length (got {len(sources)}, {len(hypotheses)}, {len(refs)})"
            )

        # Nothing to score: avoid the expensive model download/load entirely.
        if not sources:
            return []

        model = _load_comet_model(model_name)

        data = [
            {"src": src, "mt": mt, "ref": ref}
            for src, mt, ref in zip(sources, hypotheses, refs)
        ]

        # COMET's predict() chooses the execution device from `gpus`
        # (gpus=0 means CPU), so no manual model.to(...) is needed; use a
        # GPU only when CUDA is actually available.
        gpus = 1 if torch.cuda.is_available() else 0
        results = model.predict(data, batch_size=batch_size, gpus=gpus)
        return results["scores"]
    except Exception as e:
        # Best-effort contract: report the problem and return neutral scores
        # rather than propagating the exception to the caller.
        print(f"COMET calculation error: {e}")
        return [0.0] * len(source_sentences)