import os
from typing import Any, List, Optional

from dotenv import load_dotenv
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM

load_dotenv()

# Prefer the Apple Silicon MLX runtime if it is installed; otherwise fall back
# to Hugging Face Transformers so the wrapper still works off-device.
try:
    from mlx_lm import load, generate

    HAS_MLX = True
except ImportError:
    HAS_MLX = False


class MLXLLM(LLM):
    """Custom LangChain Wrapper for MLX Models (with Cloud Fallback)"""

    model_id: str = os.getenv("MODEL_ID", "mlx-community/Llama-3.2-3B-Instruct-4bit")
    model: Any = None
    tokenizer: Any = None
    max_tokens: int = int(os.getenv("MAX_TOKENS", "512"))
    pipeline: Any = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        if HAS_MLX:
            print(f"Loading MLX model: {self.model_id}")
            self.model, self.tokenizer = load(self.model_id)
        else:
            print("MLX not found. Falling back to HuggingFace Transformers (CPU/Cloud).")

            from transformers import pipeline

            # Note: if MODEL_ID points at an MLX-quantized repo, set it to a
            # transformers-compatible model for this fallback path.
            cloud_model_id = os.getenv("MODEL_ID", "gpt2")

            self.pipeline = pipeline(
                "text-generation",
                model=cloud_model_id,
                max_new_tokens=self.max_tokens,
            )

    @property
    def _llm_type(self) -> str:
        return "mlx_llama" if HAS_MLX else "transformers_fallback"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        if HAS_MLX:
            # Wrap the raw prompt in the model's chat template before generating.
            messages = [{"role": "user", "content": prompt}]
            formatted_prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            response = generate(
                self.model,
                self.tokenizer,
                prompt=formatted_prompt,
                verbose=False,
                max_tokens=self.max_tokens,
            )
            return response
        else:
            # The transformers pipeline echoes the prompt, so strip it from the output.
            response = self.pipeline(prompt)[0]["generated_text"]
            return response[len(prompt):]
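

# Minimal usage sketch (assumes MODEL_ID / MAX_TOKENS are provided via .env or
# the environment; the prompt below is illustrative only).
if __name__ == "__main__":
    llm = MLXLLM()
    # LangChain's LLM base class exposes invoke(), which routes through _call().
    print(llm.invoke("Explain what MLX is in one sentence."))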