from typing import List, Dict
from .config import LLM_PROVIDER, HF_TOKEN, GROQ_API_KEY
from huggingface_hub import InferenceClient
from groq import Groq


class LLMClient:
    """Thin wrapper that routes chat requests to either Hugging Face Inference or Groq."""

    def __init__(self, model: str, is_chat: bool = True):
        self.provider = LLM_PROVIDER
        self.model = model
        self.is_chat = is_chat
        if self.provider == "hf":
            if not HF_TOKEN:
                raise RuntimeError("HF_TOKEN is required for HF provider")
            self.client = InferenceClient(token=HF_TOKEN)
        elif self.provider == "groq":
            if not GROQ_API_KEY:
                raise RuntimeError("GROQ_API_KEY is required for Groq provider")
            self.client = Groq(api_key=GROQ_API_KEY)
        else:
            raise ValueError(f"Unsupported provider {self.provider}")

    def chat(self, messages: List[Dict[str, str]], max_tokens: int = 1024) -> str:
        if self.provider == "hf":
            # The HF text-generation endpoint takes a single prompt string,
            # so flatten the chat messages into role-tagged blocks.
            prompt = ""
            for m in messages:
                role = m.get("role", "user")
                content = m.get("content", "")
                prompt += f"[{role.upper()}]\n{content}\n"
            out = self.client.text_generation(
                prompt,
                model=self.model,
                max_new_tokens=max_tokens,
                temperature=0.2,
                do_sample=False,  # greedy decoding; temperature has no effect here
            )
            return out
        else:
            # Groq exposes an OpenAI-compatible chat completions API.
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=0.2,
            )
            return resp.choices[0].message.content
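

# Minimal usage sketch (not part of the original module): it assumes LLM_PROVIDER and
# the matching credential are already set in .config, and the model id below is only
# an illustrative example, not a value taken from this repo.
if __name__ == "__main__":
    # Build a client for whichever provider .config selects.
    client = LLMClient(model="meta-llama/Llama-3.1-8B-Instruct")  # hypothetical model id
    reply = client.chat(
        [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Summarise what an LLM client wrapper does."},
        ],
        max_tokens=256,
    )
    print(reply)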