ghosthets committed
Commit 05fe403 · verified · 1 Parent(s): 02f2b8a

Update app.py

Files changed (1)
  1. app.py +41 -22
app.py CHANGED
@@ -1,23 +1,29 @@
 import flask
 from flask import request, jsonify
-# Use AutoModelForCausalLM for Decoder-only models like TinyLlama
+# Use AutoModelForCausalLM for Decoder-only models like Qwen
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
+# Initialize the Flask application
 app = flask.Flask(__name__)
 
-model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+# Qwen1.5-0.5B-Chat Model ID
+model_id = "Qwen/Qwen1.5-0.5B-Chat"
 
-print("🔄 Loading TinyLlama model...")
+print(f"🔄 Loading {model_id} model...")
 
+# Load the tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-# Load using AutoModelForCausalLM
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) # Using bfloat16 for better memory/speed on GPU
 
+# Load the model using the correct CausalLM class
+# Using bfloat16 for better memory/speed if a compatible GPU is available
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+
+# Set the device (GPU/CPU)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
-print("✅ Model loaded instantly!")
+print(f"✅ {model_id} Model loaded successfully!")
 
 @app.route('/chat', methods=['POST'])
 def chat():
@@ -28,44 +34,57 @@ def chat():
         if not msg:
             return jsonify({"error": "No message sent"}), 400
 
-        # --- Key Change 1: Apply Chat Template ---
-        # Format the user message into the model's required chat template
+        # --- Qwen1.5 Chat Template Formatting ---
+        # Qwen models require input in the ChatML format.
         chat_history = [{"role": "user", "content": msg}]
-        # add_generation_prompt=True ensures the model knows it needs to respond
-        formatted_prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
+
+        # apply_chat_template handles the specific formatting (e.g., <|im_start|>user\n...)
+        formatted_prompt = tokenizer.apply_chat_template(
+            chat_history,
+            tokenize=False,
+            add_generation_prompt=True
+        )
 
         # Tokenize the formatted prompt
         inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
 
-        # Generation
+        # Generation configuration
         output = model.generate(
             **inputs,
             max_length=256,
             do_sample=True,
-            top_p=0.9,
-            temperature=0.7,
-            eos_token_id=tokenizer.eos_token_id
+            top_p=0.8,
+            temperature=0.6,
+            # Set pad_token_id to eos_token_id, which is often necessary for Causal LMs
+            pad_token_id=tokenizer.eos_token_id
         )
 
-        # Decode the output
+        # Decode the full output
        full_reply = tokenizer.decode(output[0], skip_special_tokens=False)
 
-        # --- Key Change 2: Extract only the generated response ---
-        # The output includes the input prompt, so we extract only the response part.
+        # --- Extract only the Generated Response ---
+
+        # Qwen ChatML format uses '<|im_start|>assistant\n' before the response
+        assistant_tag = "<|im_start|>assistant\n"
 
-        # Identify the assistant marker used by TinyLlama's chat template
-        if "[/INST]" in full_reply:
-            # This structure is often used: <s>[INST] User Prompt [/INST] Assistant Reply
-            reply = full_reply.split("[/INST]")[-1].strip()
+        if assistant_tag in full_reply:
+            # Split the full reply and take the content after the assistant tag
+            reply = full_reply.split(assistant_tag)[-1].strip()
+
+            # Remove the end-of-message tag if it was generated
+            if "<|im_end|>" in reply:
+                reply = reply.split("<|im_end|>")[0].strip()
         else:
-            # Fallback: decode only the newly generated tokens
+            # Fallback: Decode only the newly generated tokens
             reply = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
 
         return jsonify({"reply": reply})
 
     except Exception as e:
+        # Catch any runtime errors
        return jsonify({"error": str(e)}), 500
 
 
 if __name__ == "__main__":
+    # Run the Flask app
     app.run(host='0.0.0.0', port=7860)
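
For reference, a minimal sketch of what the new apply_chat_template call produces, assuming the stock Qwen/Qwen1.5-0.5B-Chat tokenizer; the exact text comes from the template bundled with the tokenizer (a default system turn may also be prepended), so this is only an illustration of the ChatML shape the route later splits on.

# Sketch: inspect the ChatML prompt the Qwen1.5 tokenizer builds.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,  # appends the '<|im_start|>assistant\n' cue before generation
)
print(prompt)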
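
The fallback branch decodes only the newly generated tokens by slicing off the prompt length instead of splitting on template tags. A standalone sketch of that idea follows; the helper name extract_reply is made up for illustration and is not part of the commit.

# Hypothetical helper mirroring the fallback branch above: decode only the tokens
# produced after the prompt, so no string-splitting on ChatML tags is needed.
def extract_reply(tokenizer, output_ids, prompt_len):
    new_tokens = output_ids[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# e.g. reply = extract_reply(tokenizer, output, inputs["input_ids"].shape[1])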
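
A hedged client sketch for exercising the /chat endpoint once the app is running: it assumes the server listens on localhost:7860 and that the JSON field read by the route is "message" (the part of app.py that parses the request body sits outside this diff, so the key name is an assumption).

# Hypothetical client call; the "message" key is an assumption, since the code that
# reads the request body is not shown in this diff. Requires the requests package.
import requests

resp = requests.post("http://localhost:7860/chat", json={"message": "Hello!"})
print(resp.status_code, resp.json())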