from typing import Iterable, List, Optional, Tuple, Union

import torch

from fairseq2.data.data_pipeline import read_sequence
from fairseq2.generation import (
    BeamSearchSeq2SeqGenerator,
    Sampler,
    SamplingSeq2SeqGenerator,
    Seq2SeqGenerator,
    SequenceToTextConverter,
)

from sonar.inference_pipelines.text import (
    EmbeddingToTextModelPipeline as _BaseEmbeddingToTextModelPipeline,
)
from sonar.inference_pipelines.utils import add_progress_bar


class EmbeddingToTextModelPipeline(_BaseEmbeddingToTextModelPipeline):
    """Drop-in replacement that can also return sentence log-probabilities via return_scores.

    - When return_scores=False (default), behaves exactly like the base pipeline and returns List[str].
    - When return_scores=True, returns a tuple (List[str], List[float]) where each float is the
      hypothesis score from fairseq2 (sum of token log-probabilities if normalize_scores=False,
      otherwise length-normalized per fairseq2 semantics).
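
    Example (a minimal sketch; the asset card names below are the usual SONAR
    text decoder/tokenizer pairing and may differ in your environment; random
    tensors stand in for real SONAR embeddings):

        import torch

        pipe = EmbeddingToTextModelPipeline(
            decoder="text_sonar_basic_decoder",
            tokenizer="text_sonar_basic_encoder",
        )
        embeddings = torch.randn(2, 1024)  # stand-in for 1024-dim SONAR embeddings
        texts, scores = pipe.predict(
            embeddings, target_lang="eng_Latn", return_scores=True
        )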
    """

    @torch.inference_mode()
    def predict(
        self,
        inputs: torch.Tensor,
        target_lang: str,
        batch_size: int = 5,
        progress_bar: bool = False,
        sampler: Optional[Sampler] = None,
        return_scores: bool = False,
        **generator_kwargs,
    ) -> Union[List[str], Tuple[List[str], List[float]]]:
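        # Pick the fairseq2 generator: sampling when a Sampler is provided,
        # beam search otherwise. Remaining keyword arguments (such as
        # normalize_scores) are forwarded to the generator unchanged.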
        if sampler is not None:
            generator: Seq2SeqGenerator = SamplingSeq2SeqGenerator(
                self.model, sampler, **generator_kwargs
            )
        else:
            generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs)

        converter = SequenceToTextConverter(
            generator,
            self.tokenizer,
            task="translation",
            target_lang=target_lang,
        )

        def _do_translate(src_tensors: List[torch.Tensor]):
            texts, gen_out = converter.batch_convert(
                torch.stack(src_tensors).to(self.device), None
            )
            if return_scores:
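                # Take the score of the first (top-ranked) hypothesis for each
                # item; fall back to 0.0 when the generator produced no
                # hypothesis or no score.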
                scores: List[float] = []
                for hyps in gen_out.hypotheses:
                    if len(hyps) == 0 or hyps[0].score is None:
                        scores.append(0.0)
                    else:
                        scores.append(float(hyps[0].score))
                return texts, scores
            return texts

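        # Chunk the embeddings into batches of `batch_size` and map the
        # translation step over each batch via the fairseq2 data pipeline.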
        pipeline: Iterable = (
            read_sequence(list(inputs))
            .bucket(batch_size)
            .map(_do_translate)
            .and_return()
        )

        if progress_bar:
            pipeline = add_progress_bar(pipeline, inputs=inputs, batch_size=batch_size)

        results: List = list(iter(pipeline))

        if not return_scores:
            # results is List[List[str]] → flatten
            return [text for batch_texts in results for text in batch_texts]

        # results is List[Tuple[List[str], List[float]]] → flatten both
        all_texts: List[str] = []
        all_scores: List[float] = []
        for batch in results:
            batch_texts, batch_scores = batch
            all_texts.extend(batch_texts)
            all_scores.extend(batch_scores)
        return all_texts, all_scores