pszemraj committed
Commit 5bc353f · verified · 1 Parent(s): 8f03206

Update app.py

Files changed (1)
  1. app.py +183 -52
app.py CHANGED
@@ -1,11 +1,12 @@
 from __future__ import annotations
+
 import os
-from typing import List, Tuple, Dict, Any
-import spaces
+from typing import Any, Dict, List, Tuple

 import gradio as gr
+import spaces
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer

 # ----------------------
 # Config
@@ -19,6 +20,7 @@ DEFAULT_SYSTEM_PROMPT = (

 device = "cuda" if torch.cuda.is_available() else "cpu"

+
 def load_model(model_id: str = MODEL_ID):
     """Load tokenizer and model, with a reasonable dtype and device fallback."""
     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
@@ -36,16 +38,25 @@ def load_model(model_id: str = MODEL_ID):
     end_token_ids = tokenizer.encode(end_token, add_special_tokens=False)
     end_conv_token_ids = tokenizer.encode(end_conv_token, add_special_tokens=False)

-    # Some models may not include these tokens handle gracefully
-    eos_token_id = end_token_ids[0] if len(end_token_ids) > 0 else tokenizer.eos_token_id
+    # Guardrail 1: Problematic first tokens that cause repetition (from Appendix C.1)
+    problematic_tokens = ["I", "You", "Here", "i", "you", "here"]
+    first_token_filter_ids = []
+    for token in problematic_tokens:
+        token_ids = tokenizer.encode(token, add_special_tokens=False)
+        if len(token_ids) > 0:
+            first_token_filter_ids.append(token_ids[0])
+
+    eos_token_id = (
+        end_token_ids[0] if len(end_token_ids) > 0 else tokenizer.eos_token_id
+    )
     bad_words_ids = (
         [[tid] for tid in end_conv_token_ids] if len(end_conv_token_ids) > 0 else None
     )

-    return tokenizer, model, eos_token_id, bad_words_ids
+    return tokenizer, model, eos_token_id, bad_words_ids, first_token_filter_ids


-tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS = load_model()
+tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS, FIRST_TOKEN_FILTER_IDS = load_model()
 model = model.to(device)
 model.eval()

@@ -53,7 +64,10 @@ model.eval()
 # Generation helper
 # ----------------------

-def build_messages(system_prompt: str, history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
+
+def build_messages(
+    system_prompt: str, history: List[Tuple[str, str]]
+) -> List[Dict[str, str]]:
     """Transform Gradio history [(user, assistant), ...] into chat template messages."""
     messages: List[Dict[str, str]] = []
     if system_prompt.strip():
@@ -66,51 +80,123 @@ def build_messages(system_prompt: str, history: List[Tuple[str, str]]) -> List[D
     return messages


+def apply_first_token_filter(
+    logits: torch.Tensor, filter_ids: List[int]
+) -> torch.Tensor:
+    """Apply logit filter for problematic first tokens (Guardrail 1)."""
+    logits_filtered = logits.clone()
+    for token_id in filter_ids:
+        logits_filtered[0, -1, token_id] = float("-inf")
+    return logits_filtered
+
+
+def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
+    """Check if generated text meets length requirements (Guardrail 3).
+
+    Paper used max_words=25 for their simulation experiments, but we use 50
+    for interactive demo to allow slightly longer responses while still preventing
+    the model from revealing the entire intent at once.
+    """
+    word_count = len(text.split())
+    return min_words <= word_count <= max_words
+
+
+def is_verbatim_repetition(
+    new_text: str, history: List[Tuple[str, str]], system_prompt: str
+) -> bool:
+    """Check if text is exact repetition of prior user turn or system prompt (Guardrail 4)."""
+    new_text_normalized = new_text.strip().lower()
+
+    # Check against system prompt
+    if new_text_normalized == system_prompt.strip().lower():
+        return True
+
+    # Check against previous user messages
+    for user_msg, _ in history:
+        if user_msg and new_text_normalized == user_msg.strip().lower():
+            return True
+
+    return False
+
+
 @spaces.GPU
 def generate_reply(
     messages: List[Dict[str, str]],
+    history: List[Tuple[str, str]],
+    system_prompt: str,
     max_new_tokens: int = 256,
-    temperature: float = 0.8,
-    top_p: float = 0.9,
+    temperature: float = 1.0,
+    top_p: float = 0.8,
+    max_retries: int = 5,
 ) -> str:
-    """Run a single generate() step and return the model's text reply."""
-    # Prepare input ids using the model's chat template
-    inputs = tokenizer.apply_chat_template(
-        messages,
-        return_tensors="pt",
-        add_generation_prompt=True,
-    ).to(device)
-
-    with torch.no_grad():
-        outputs = model.generate(
-            input_ids=inputs,
-            do_sample=True,
-            top_p=top_p,
-            temperature=temperature,
-            max_new_tokens=max_new_tokens,
-            eos_token_id=EOS_TOKEN_ID,
-            pad_token_id=tokenizer.eos_token_id,
-            bad_words_ids=BAD_WORDS_IDS,
-        )
+    """Run generation with guardrails from Appendix C.1.
+
+    Implements all 4 guardrails from the paper:
+    1. Filter problematic first tokens
+    2. Optionally avoid dialogue termination (disabled by default for demo)
+    3. Enforce length thresholds with retry
+    4. Filter verbatim repetitions with retry
+    """
+
+    for attempt in range(max_retries):
+        # Prepare input ids using the model's chat template
+        inputs = tokenizer.apply_chat_template(
+            messages,
+            return_tensors="pt",
+            add_generation_prompt=True,
+        ).to(device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                input_ids=inputs,
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                eos_token_id=EOS_TOKEN_ID,
+                pad_token_id=tokenizer.eos_token_id,
+                bad_words_ids=BAD_WORDS_IDS,  # Prevents <|endconversation|>
+            )
+
+        # Slice off the prompt tokens to get only the new text
+        generated = outputs[0][inputs.shape[1] :]
+        text = tokenizer.decode(generated, skip_special_tokens=True).strip()
+
+        # Apply guardrails - retry if checks fail
+        if not is_valid_length(text):
+            continue
+
+        if is_verbatim_repetition(text, history, system_prompt):
+            continue
+
+        # Success - return the valid text
+        return text

-    # Slice off the prompt tokens to get only the new text
-    generated = outputs[0][inputs.shape[1]:]
-    text = tokenizer.decode(generated, skip_special_tokens=True).strip()
-    return text
+    # If all retries failed, return a fallback message
+    return "(Unable to generate valid response after multiple attempts)"


 # ----------------------
 # Gradio UI callbacks
 # ----------------------

-def respond(user_message: str, chat_history: List[Tuple[str, str]], system_prompt: str,
-            max_new_tokens: int, temperature: float, top_p: float):
+
+def respond(
+    user_message: str,
+    chat_history: List[Tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+):
     # Build messages including prior turns
     messages = build_messages(system_prompt, chat_history + [(user_message, "")])

     try:
         reply = generate_reply(
             messages,
+            chat_history,
+            system_prompt,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
             top_p=top_p,
@@ -130,45 +216,90 @@ def clear_state():
 # Build the Gradio App
 # ----------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("""
-    # 🧪 Transformers × Gradio: Multi‑turn Chat Demo
+    gr.Markdown(
+        f"""
+    # UserLM-8b: User Language Model Demo
+
+    **Model:** `{MODEL_ID}` on **{device}**

-    Model: **{model}** on **{device}**
+    This demo implements the generation guardrails from [Appendix C.1](https://arxiv.org/abs/2510.06552) of the paper:
+    - Filters problematic first tokens (I, You, Here) that cause repetition
+    - Enforces length thresholds (3-50 words per turn)
+    - Prevents verbatim repetition of prior turns
+    - Uses recommended sampling params: temp=1.0, top_p=0.8

-    Change the system prompt, then chat. Sliders control sampling.
-    """.format(model=MODEL_ID, device=device))
+    **Note:** Unlike typical assistant LMs, UserLM simulates *human users* in conversations.
+    The system prompt defines the user's high-level intent.
+    """
+    )

     with gr.Row():
         system_box = gr.Textbox(
-            label="System Prompt",
+            label="User Intent (System Prompt)",
             value=DEFAULT_SYSTEM_PROMPT,
             lines=3,
-            placeholder="Enter a system instruction to steer the assistant",
+            placeholder="Enter a high-level user intent (e.g., 'You are a user who wants to...')",
         )

-    chatbot = gr.Chatbot(height=420, label="Chat")
+    chatbot = gr.Chatbot(height=420, label="Simulated User-Assistant Conversation")

     with gr.Row():
         msg = gr.Textbox(
-            label="Your message",
-            placeholder="Type a message and press Enter",
+            label="Assistant Response",
+            placeholder="Type the assistant's response to the user",
+            lines=2,
         )

-    with gr.Accordion("Generation Settings", open=False):
-        max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="max_new_tokens")
-        temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="temperature")
-        top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
+    with gr.Accordion(
+        "Generation Settings (Based on Paper Recommendations)", open=False
+    ):
+        max_new_tokens = gr.Slider(
+            16,
+            512,
+            value=256,
+            step=16,
+            label="max_new_tokens",
+            info="Max tokens per user turn. Paper used stricter limits for simulation.",
+        )
+        temperature = gr.Slider(
+            0.0,
+            2.0,
+            value=1.0,
+            step=0.05,
+            label="temperature",
+            info="Paper recommends 1.0 for realistic user diversity",
+        )
+        top_p = gr.Slider(
+            0.0,
+            1.0,
+            value=0.8,
+            step=0.01,
+            label="top_p",
+            info="Paper recommends 0.8 (not 0.9)",
+        )

     with gr.Row():
-        submit_btn = gr.Button("Send", variant="primary")
+        submit_btn = gr.Button("Generate User Response", variant="primary")
         clear_btn = gr.Button("Clear")

     state = gr.State([])  # chat history state: List[Tuple[user, assistant]]

+    gr.Markdown(
+        """
+    ### Usage Tips:
+    - The **system prompt** defines the user's goal (keep it high-level, not overly specific)
+    - Type what the **assistant says** in response
+    - Click **Generate User Response** to simulate how a human user would reply
+    - UserLM naturally reveals intent across multiple turns, not all at once
+    """
+    )
+
     def _submit(user_text, history, system_prompt, mnt, temp, tp):
         if not user_text or not user_text.strip():
             return gr.update(), history
-        new_history, visible = respond(user_text.strip(), history, system_prompt, mnt, temp, tp)
+        new_history, visible = respond(
+            user_text.strip(), history, system_prompt, mnt, temp, tp
+        )
         return "", visible

     submit_btn.click(
@@ -195,4 +326,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     clear_btn.click(_clear, outputs=[state, system_box, chatbot, msg])

 if __name__ == "__main__":
-    demo.queue().launch()  # enable queuing for concurrency
+    demo.queue().launch()
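The turn-level guardrails added here, `is_valid_length` (Guardrail 3) and `is_verbatim_repetition` (Guardrail 4), are plain string checks, so their retry-triggering behaviour can be sanity-checked without loading UserLM-8b. Below is a minimal, model-free sketch: the two helpers are copied from the updated `app.py`, while the example intent and history are invented for illustration.

```python
from typing import List, Tuple


# Copied from the updated app.py so the sketch runs standalone.
def is_valid_length(text: str, min_words: int = 3, max_words: int = 50) -> bool:
    word_count = len(text.split())
    return min_words <= word_count <= max_words


def is_verbatim_repetition(
    new_text: str, history: List[Tuple[str, str]], system_prompt: str
) -> bool:
    new_text_normalized = new_text.strip().lower()
    if new_text_normalized == system_prompt.strip().lower():
        return True
    for user_msg, _ in history:
        if user_msg and new_text_normalized == user_msg.strip().lower():
            return True
    return False


# Hypothetical intent and history, purely for illustration.
system_prompt = "You are a user who wants help debugging a Python script."
history = [("My script crashes with a KeyError.", "Can you share the traceback?")]

# Under 3 words: rejected, so generate_reply() would resample.
assert not is_valid_length("Sure.")
# Between 3 and 50 words: accepted.
assert is_valid_length("The full traceback from my last run is pasted below.")
# Echoing an earlier user turn verbatim: rejected as repetition.
assert is_verbatim_repetition("My script crashes with a KeyError.", history, system_prompt)
# A new, on-topic reply passes both checks.
assert not is_verbatim_repetition("It only crashes when the config dict is empty.", history, system_prompt)
print("guardrail checks behave as expected")
```

In the Space itself these checks run inside the retry loop of `generate_reply`, which resamples up to `max_retries` times before returning the fallback string.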
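Guardrail 1 appears in this commit as `FIRST_TOKEN_FILTER_IDS` and `apply_first_token_filter`, which expects raw forward-pass logits of shape `[batch, seq, vocab]`; the `model.generate()` call in `generate_reply` above is not wired to it. For reference only, one way to get the same first-token suppression inside `generate()` would be a small custom `LogitsProcessor`. This is a hedged sketch, not part of the commit, and `FirstTokenFilter` is a hypothetical name.

```python
import torch
from typing import List
from transformers import LogitsProcessor, LogitsProcessorList


class FirstTokenFilter(LogitsProcessor):
    """Mask the configured token ids, but only at the first generated position."""

    def __init__(self, filter_ids: List[int], prompt_len: int):
        self.filter_ids = filter_ids
        self.prompt_len = prompt_len

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # scores is [batch, vocab]; input_ids grows by one column per generated
        # token, so the prompt length identifies the very first sampling step.
        if input_ids.shape[1] == self.prompt_len:
            scores[:, self.filter_ids] = float("-inf")
        return scores


# Hypothetical usage inside generate_reply(), after `inputs` is built:
# outputs = model.generate(
#     input_ids=inputs,
#     logits_processor=LogitsProcessorList(
#         [FirstTokenFilter(FIRST_TOKEN_FILTER_IDS, inputs.shape[1])]
#     ),
#     ...  # keep the existing sampling arguments unchanged
# )
```

Recent transformers releases also accept a `begin_suppress_tokens` argument to `generate()` that suppresses tokens at the first step, which may be a simpler route if it is available in the pinned version.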