import flask
from flask import request, jsonify
# Use AutoModelForCausalLM for Decoder-only models like Qwen
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Initialize the Flask application
app = flask.Flask(__name__)
# Qwen1.5-0.5B-Chat Model ID
model_id = "Qwen/Qwen1.5-0.5B-Chat"
print(f"🔄 Loading {model_id} model...")
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Select the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the model using the correct CausalLM class.
# bfloat16 saves memory and speeds things up on a compatible GPU; fall back to float32 on CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32
)
model.to(device)
print(f"✅ {model_id} model loaded successfully!")
@app.route('/chat', methods=['POST'])
def chat():
    try:
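        # The client is expected to POST a JSON body shaped like:
        #     {"message": "<user text>"}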
        # silent=True yields None (instead of raising) when the body is not valid JSON
        data = request.get_json(silent=True) or {}
        msg = data.get("message", "")
        if not msg:
            return jsonify({"error": "No message sent"}), 400
        # --- Qwen1.5 Chat Template Formatting ---
        # Qwen models require input in the ChatML format.
        chat_history = [{"role": "user", "content": msg}]
        # apply_chat_template handles the specific formatting (e.g., <|im_start|>user\n...)
        formatted_prompt = tokenizer.apply_chat_template(
            chat_history,
            tokenize=False,
            add_generation_prompt=True
        )
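        # For reference, with add_generation_prompt=True the formatted prompt looks
        # roughly like this (Qwen1.5's chat template also prepends a default system turn):
        #
        #   <|im_start|>system
        #   You are a helpful assistant<|im_end|>
        #   <|im_start|>user
        #   {msg}<|im_end|>
        #   <|im_start|>assistant
        #
        # The trailing assistant tag is what cues the model to generate the reply.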
        # Tokenize the formatted prompt
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        # Generation configuration
        output = model.generate(
            **inputs,
            # max_new_tokens caps only the generated tokens; max_length would also count the prompt
            max_new_tokens=256,
            do_sample=True,
            top_p=0.8,
            temperature=0.6,
            # Set pad_token_id to eos_token_id, which is often necessary for Causal LMs
            pad_token_id=tokenizer.eos_token_id
        )
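        # Note: top_p/temperature sampling gives varied replies on repeated calls;
        # pass do_sample=False instead for deterministic, greedy decoding.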
        # Decode the full output
        full_reply = tokenizer.decode(output[0], skip_special_tokens=False)
        # --- Extract only the Generated Response ---
        # Qwen ChatML format uses '<|im_start|>assistant\n' before the response
        assistant_tag = "<|im_start|>assistant\n"
        if assistant_tag in full_reply:
            # Split the full reply and take the content after the assistant tag
            reply = full_reply.split(assistant_tag)[-1].strip()
            # Remove the end-of-message tag if it was generated
            if "<|im_end|>" in reply:
                reply = reply.split("<|im_end|>")[0].strip()
        else:
            # Fallback: decode only the newly generated tokens
            reply = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
        return jsonify({"reply": reply})
    except Exception as e:
        # Catch any runtime errors
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Run the Flask app
    app.run(host='0.0.0.0', port=7860)
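# Example client call (hypothetical; assumes the `requests` package is installed and
# the server is reachable at localhost:7860, the host/port configured above):
#
#   import requests
#   r = requests.post("http://localhost:7860/chat", json={"message": "Hello! Who are you?"})
#   print(r.json())  # -> {"reply": "..."} on success, {"error": "..."} with status 400/500 otherwise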