yuto0o
ninja移行
9332d9b
raw
history blame
1.6 kB
import torch
from ninja import NinjaAPI
from .model_loader import get_model
from .schemas import ChatInput, ChatOutput
# APIインスタンスの作成
api = NinjaAPI()
@api.post("/chat", response=ChatOutput)
def chat(request, data: ChatInput):
"""
Qwenモデルを使用したチャットAPI
"""
user_input = data.text # Schema経由で安全にアクセス
# モデルのロード(初回のみロードが走る)
model, tokenizer = get_model()
# 1. 会話フォーマットの作成
messages = [
{
"role": "system",
"content": "あなたは親切でフレンドリーなAIアシスタント「qwen」です。自然な日本語で簡潔に返事をしてください。",
},
{"role": "user", "content": user_input},
]
# 2. プロンプトへの変換
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
# 3. 生成
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
top_p=0.9,
)
# 4. デコード
generated_ids = [
output_ids[len(input_ids) :]
for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
]
response_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# ChatOutputスキーマに合わせてdictを返す
return {"result": response_text}