File size: 1,333 Bytes
4be79fa
3c397fd
 
5d5c713
3c397fd
 
4be79fa
5d5c713
 
3c397fd
5d5c713
 
 
 
 
3c397fd
 
 
 
5d5c713
3c397fd
 
 
5d5c713
3c397fd
 
 
 
 
0db2cde
3c397fd
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import torch
from comet import download_model, load_from_checkpoint

# Point COMET's cache at /tmp through the COMET_CACHE environment variable
# (presumably because the default cache location is not writable where this
# runs -- TODO confirm).
# NOTE(review): this assignment executes after `comet` has already been
# imported above; confirm the library reads COMET_CACHE lazily (at download
# time) rather than once at import time, otherwise it has no effect.
os.environ["COMET_CACHE"] = "/tmp"

# Process-wide cache so the checkpoint is downloaded and deserialized only
# once instead of on every call to calculate_comet().
_COMET_MODEL = None


def _load_comet_model():
    """Return the COMET model, loading it on first use and reusing it after."""
    global _COMET_MODEL
    if _COMET_MODEL is None:
        model_path = download_model("Unbabel/wmt22-comet-da")
        _COMET_MODEL = load_from_checkpoint(model_path)
    return _COMET_MODEL


def calculate_comet(source_sentences, translations, references):
    """
    Calculate COMET scores for a list of translations.

    The three inputs are paired positionally with ``zip``; if their lengths
    differ, the extra items of the longer lists are silently ignored and the
    result contains one score per complete (src, mt, ref) triple.

    This function is best-effort: any failure while loading the model or
    predicting is printed and replaced by a list of ``0.0`` scores, one per
    source sentence, so callers never see an exception from COMET itself.

    :param source_sentences: List of source sentences.
    :param translations: List of translated sentences (hypotheses).
    :param references: List of reference translations.
    :return: List of COMET scores (one score per sentence triple); an empty
        list when there are no source sentences.
    """
    # Nothing to score: skip the (expensive) model download/load entirely.
    if not source_sentences:
        return []

    try:
        model = _load_comet_model()

        # Prepare data for COMET
        data = [
            {"src": src, "mt": mt, "ref": ref}
            for src, mt, ref in zip(source_sentences, translations, references)
        ]

        # Let predict() place the model: request one GPU when CUDA is
        # available, otherwise run on CPU. Previously the model was moved to
        # CUDA by hand while predict() was still told to use zero GPUs.
        gpus = 1 if torch.cuda.is_available() else 0

        # Compute COMET scores
        results = model.predict(data, batch_size=8, gpus=gpus)
        return results["scores"]
    except Exception as e:
        print(f"COMET calculation error: {e}")
        return [0.0] * len(source_sentences)  # Return default scores on error