from typing import Dict, List

from groq import Groq
from huggingface_hub import InferenceClient

from .config import GROQ_API_KEY, HF_TOKEN, LLM_PROVIDER
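
# `.config` is assumed to be a sibling module in the same package that loads
# these settings from the environment. A minimal sketch of what it might look
# like (the env var names here are assumptions, not taken from the source):
#
#   import os
#
#   LLM_PROVIDER = os.getenv("LLM_PROVIDER", "hf")  # "hf" or "groq"
#   HF_TOKEN = os.getenv("HF_TOKEN")
#   GROQ_API_KEY = os.getenv("GROQ_API_KEY")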

class LLMClient:
    """Thin wrapper that routes chat requests to either the Hugging Face
    Inference API or the Groq API, depending on LLM_PROVIDER."""

    def __init__(self, model: str, is_chat: bool = True):
        self.provider = LLM_PROVIDER
        self.model = model
        # is_chat is stored but not read anywhere in this class yet.
        self.is_chat = is_chat
        if self.provider == "hf":
            if not HF_TOKEN:
                raise RuntimeError("HF_TOKEN is required for HF provider")
            self.client = InferenceClient(token=HF_TOKEN)
        elif self.provider == "groq":
            if not GROQ_API_KEY:
                raise RuntimeError("GROQ_API_KEY is required for Groq provider")
            self.client = Groq(api_key=GROQ_API_KEY)
        else:
            raise ValueError(f"Unsupported provider {self.provider!r}")
    def chat(self, messages: List[Dict[str, str]], max_tokens: int = 1024) -> str:
        if self.provider == "hf":
            # The HF text_generation endpoint takes a single prompt string, so
            # flatten the chat messages into role-tagged blocks, e.g.
            # [{"role": "user", "content": "Hi"}] becomes "[USER]\nHi\n".
            prompt = ""
            for m in messages:
                role = m.get("role", "user")
                content = m.get("content", "")
                prompt += f"[{role.upper()}]\n{content}\n"
            out = self.client.text_generation(
                prompt,
                model=self.model,
                max_new_tokens=max_tokens,
                # temperature is only honored when sampling is enabled;
                # with greedy decoding (do_sample=False) it would be ignored.
                do_sample=True,
                temperature=0.2,
            )
            return out
        else:
            # Groq exposes an OpenAI-compatible chat completions API, so the
            # messages list is passed through unchanged.
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=0.2,
            )
            return resp.choices[0].message.content
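
# Example usage, as a sketch: assumes the package is importable, LLM_PROVIDER
# is set to "groq", and GROQ_API_KEY is valid. The model name below is a
# placeholder, not taken from the source.
#
#   client = LLMClient(model="llama-3.1-8b-instant")
#   reply = client.chat([{"role": "user", "content": "Say hello in one word."}])
#   print(reply)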