| | from pylate import models |
| | from transformers import AutoTokenizer |
| | import torch |
| | import numpy as np |
| |
|
class EndpointHandler:
    """Inference-endpoint handler that encodes text with a PyLate ColBERT model.

    Every request is encoded with ``is_query=True`` and the resulting
    embeddings are converted to nested Python lists so the response is
    JSON-serialisable.
    """

    # Batch size forwarded to ``ColBERT.encode``.
    batch_size = 32

    def __init__(self, path=""):
        """Load the tokenizer and ColBERT model from ``path``.

        Args:
            path: model directory or id, passed to both
                ``AutoTokenizer.from_pretrained`` and ``models.ColBERT``.
        """
        # NOTE(review): the tokenizer is not used anywhere in this class;
        # kept in case external code reads ``handler.tokenizer``.
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        self.model = models.ColBERT(model_name_or_path=path)
        # Inference only: put the module in evaluation mode.
        self.model.eval()

    def _to_list(self, emb):
        """Recursively convert ``emb`` into JSON-serialisable Python objects.

        - ``torch.Tensor``     -> ``emb.cpu().tolist()``
        - ``np.ndarray``       -> ``emb.tolist()``
        - NumPy scalar         -> ``emb.item()`` (plain int/float/bool)
        - ``list`` / ``tuple`` -> list with every element converted
        - anything else        -> returned unchanged
        """
        if isinstance(emb, torch.Tensor):
            # Move to host memory before converting to Python scalars.
            return emb.cpu().tolist()
        if isinstance(emb, np.ndarray):
            return emb.tolist()
        if isinstance(emb, np.generic):
            # NumPy scalars such as np.float32 are rejected by ``json``.
            return emb.item()
        if isinstance(emb, (list, tuple)):
            return [self._to_list(e) for e in emb]
        return emb

    @staticmethod
    def _extract_texts(data):
        """Normalise a request payload into a list of texts.

        Accepted payloads:
          - a dict with an ``"inputs"`` or ``"text"`` field (checked in that
            order; a field whose value is ``None`` counts as absent);
          - the texts themselves: a ``str`` or a ``list``/``tuple``.

        Returns:
            A new list; a single string is wrapped as ``[string]``.

        Raises:
            ValueError: a dict payload has neither ``"inputs"`` nor ``"text"``.
            TypeError: the texts are neither a string nor a list/tuple.
        """
        if isinstance(data, dict):
            # Explicit ``is not None`` so legitimate falsy values ("" or [])
            # are not silently skipped in favour of the next key.
            for key in ("inputs", "text"):
                if data.get(key) is not None:
                    texts = data[key]
                    break
            else:
                raise ValueError(
                    "request payload must contain an 'inputs' or 'text' field"
                )
        else:
            texts = data

        if isinstance(texts, str):
            return [texts]
        if isinstance(texts, (list, tuple)):
            return list(texts)
        raise TypeError(
            f"expected a string or a list of strings, got {type(texts).__name__}"
        )

    def __call__(self, data):
        """Encode the request's texts and return JSON-serialisable embeddings.

        Args:
            data: request payload; see ``_extract_texts`` for accepted shapes.

        Returns:
            The output of ``self.model.encode(..., is_query=True)`` converted
            by ``_to_list``, or ``[]`` when there are no texts to encode.
        """
        texts = self._extract_texts(data)
        if not texts:
            # Nothing to encode; skip the model call for an empty batch.
            return []

        with torch.no_grad():
            emb = self.model.encode(
                texts,
                is_query=True,
                batch_size=self.batch_size,
            )

        return self._to_list(emb)
| |
|