# UserLM / app.py
from __future__ import annotations
import os
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
# ======================
# Config
# ======================
MODEL_ID = os.getenv("MODEL_ID", "microsoft/UserLM-8b")
DEFAULT_SYSTEM_PROMPT = (
"You are a user who wants to compute rolling 7-day averages over uneven time stamps. "
"You are suspicious of resampling magic and will accuse the assistant of witchcraft if it's not explicit."
)
# ======================
# Load model
# ======================
def load_model(model_id: str = MODEL_ID):
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
mdl = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="auto",
)
    # Special tokens: stop on <|eot_id|> and ban <|endconversation|> (model-card defaults).
    eot = "<|eot_id|>"
    end_conv = "<|endconversation|>"
    eot_ids = tok.encode(eot, add_special_tokens=False)
    end_conv_ids = tok.encode(end_conv, add_special_tokens=False)
    # Assumes <|eot_id|> encodes to a single id; otherwise fall back to the tokenizer EOS.
    eos_token_id = eot_ids[0] if eot_ids else tok.eos_token_id
    # Ban each sub-token of <|endconversation|> individually (aggressive but simple:
    # if the marker were multi-token, every piece would be blocked everywhere).
    bad_words_ids = [[tid] for tid in end_conv_ids] if end_conv_ids else None
    # Guardrail 1 (Appendix C.1): ban assistant-style openers ("I", "You", "Here")
    # as the *first* generated token. Only the first sub-token of each word is used.
    prob_first_tokens = ["I", "You", "Here", "i", "you", "here"]
    first_token_filter_ids = []
    for w in prob_first_tokens:
        ids = tok.encode(w, add_special_tokens=False)
        if ids:
            first_token_filter_ids.append(ids[0])
return tok, mdl, eos_token_id, bad_words_ids, first_token_filter_ids
tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS, FIRST_TOKEN_FILTER_IDS = load_model()
model.generation_config.eos_token_id = EOS_TOKEN_ID
model.generation_config.pad_token_id = tokenizer.eos_token_id
model.eval()
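# Optional sanity check (a sketch, not called at import): the guardrail setup
# above assumes '<|eot_id|>' and '<|endconversation|>' each map to a dedicated
# token id. If either splits into several pieces, only the first piece becomes
# EOS and each piece of the end-conversation marker is banned separately.
def _check_special_tokens() -> None:
    for t in ("<|eot_id|>", "<|endconversation|>"):
        ids = tokenizer.encode(t, add_special_tokens=False)
        tag = "single token" if len(ids) == 1 else "multi-token!"
        print(f"{t!r} -> {ids} ({tag})")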
# ======================
# Guardrail helpers
# ======================
def is_valid_length(text: str, min_words: int = 3, max_words: int = 25) -> bool:
    """Guardrail 3: keep utterances between min_words and max_words, inclusive."""
wc = len(text.split())
return min_words <= wc <= max_words
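# Quick sketch of the length bounds (inclusive on both ends); call manually if curious.
def _demo_is_valid_length() -> None:
    assert not is_valid_length("too short")                 # 2 words -> rejected
    assert is_valid_length("five words is enough here")     # 5 words -> accepted
    assert not is_valid_length(" ".join(["word"] * 26))     # 26 words -> rejected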
def is_verbatim_repetition(
    new_text: str, history_pairs: List[Tuple[str, Optional[str]]], system_prompt: str
) -> bool:
    """Guardrail 4: reject text that repeats the system prompt or any earlier
    UserLM utterance verbatim (case-insensitive)."""
t = new_text.strip().lower()
if t == system_prompt.strip().lower():
return True
for model_user, _ in history_pairs:
if model_user and t == model_user.strip().lower():
return True
return False
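# Sketch: the check is case-insensitive against the system prompt and every
# earlier UserLM turn (the strings below are made up for illustration).
def _demo_is_verbatim_repetition() -> None:
    hist = [("How do I do this?", "Like so.")]
    assert is_verbatim_repetition("how do i do this?", hist, "some goal")
    assert is_verbatim_repetition("Some goal", hist, "some goal")
    assert not is_verbatim_repetition("Something new entirely", hist, "some goal")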
class ForbidFirstToken(LogitsProcessor):
"""Set -inf on a token list for the *first* generated token only."""
def __init__(self, forbid_ids: List[int], prompt_len: int):
self.forbid = list(set(int(x) for x in forbid_ids))
self.prompt_len = int(prompt_len)
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
# Apply only when generating the very first token (seq len == prompt_len)
if input_ids.shape[1] == self.prompt_len and self.forbid:
scores[:, self.forbid] = float("-inf")
return scores
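# Minimal demonstration of the processor on dummy tensors (shapes and ids here
# are invented; nothing below touches the real model).
def _demo_forbid_first_token() -> None:
    proc = ForbidFirstToken(forbid_ids=[3, 7], prompt_len=4)
    prompt = torch.zeros((1, 4), dtype=torch.long)   # sequence length == prompt length
    first_scores = proc(prompt, torch.zeros((1, 10)))
    assert torch.isinf(first_scores[0, 3]) and torch.isinf(first_scores[0, 7])
    longer = torch.zeros((1, 5), dtype=torch.long)   # any later step: left untouched
    later_scores = proc(longer, torch.zeros((1, 10)))
    assert torch.isfinite(later_scores).all()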
# ======================
# Message utilities
# ======================
def build_hf_messages(
system_prompt: str, history_pairs: List[Tuple[str, Optional[str]]]
) -> List[Dict[str, str]]:
"""
Construct messages for tokenizer.apply_chat_template.
history_pairs = list of (model_user, human_assistant)
"""
msgs: List[Dict[str, str]] = []
if system_prompt.strip():
msgs.append({"role": "system", "content": system_prompt.strip()})
for model_user, human_assistant in history_pairs:
if model_user:
msgs.append({"role": "user", "content": model_user})
if human_assistant:
msgs.append({"role": "assistant", "content": human_assistant})
return msgs
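# Sketch of the mapping: one closed pair plus one open pair become four messages.
def _demo_build_hf_messages() -> None:
    pairs = [("How do I average over gaps?", "Use a time-based window."),
             ("That sounds like witchcraft.", None)]
    msgs = build_hf_messages("Compute rolling averages.", pairs)
    assert [m["role"] for m in msgs] == ["system", "user", "assistant", "user"]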
def pairs_to_ui_messages(
history_pairs: List[Tuple[str, Optional[str]]]
) -> List[Dict[str, str]]:
"""
Convert (model_user, human_assistant) pairs to Gradio Chatbot(type='messages') UI messages.
Visual convention:
- LEFT (role='assistant'): UserLM's utterances (the simulator)
- RIGHT (role='user'): Your replies (you play the assistant)
"""
ui: List[Dict[str, str]] = []
for model_user, human_assistant in history_pairs:
if model_user:
ui.append({"role": "assistant", "content": model_user})
if human_assistant:
ui.append({"role": "user", "content": human_assistant})
return ui
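# Sketch: the UI roles are deliberately inverted relative to the chat template;
# UserLM's text renders on the left ('assistant') and yours on the right ('user').
def _demo_pairs_to_ui_messages() -> None:
    ui = pairs_to_ui_messages([("hello from UserLM", "hi, I'm the assistant")])
    assert [m["role"] for m in ui] == ["assistant", "user"]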
# ======================
# Generation
# ======================
@spaces.GPU
def generate_reply(
system_prompt: str,
history_pairs: List[Tuple[str, Optional[str]]],
max_new_tokens: int = 128,
temperature: float = 1.0,
top_p: float = 0.8,
max_retries: int = 10,
) -> str:
"""Implements the 4 guardrails from Appendix C.1 and passes an explicit attention_mask."""
messages = build_hf_messages(system_prompt, history_pairs)
inputs = tokenizer.apply_chat_template(
messages, return_tensors="pt", add_generation_prompt=True
).to(model.device)
# Robust attention mask even when pad_token_id == eos_token_id.
# If no padding is present (usual single-sequence case), use all-ones.
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
if pad_id is not None and (inputs == pad_id).any():
attention_mask = (inputs != pad_id).long()
else:
attention_mask = torch.ones_like(inputs, dtype=torch.long)
for _ in range(max_retries):
lp = LogitsProcessorList(
[ForbidFirstToken(FIRST_TOKEN_FILTER_IDS, prompt_len=inputs.shape[1])]
)
with torch.no_grad():
out = model.generate(
input_ids=inputs,
attention_mask=attention_mask, # <-- explicit mask to silence warning & be robust
do_sample=True,
top_p=top_p,
temperature=temperature,
max_new_tokens=max_new_tokens,
eos_token_id=EOS_TOKEN_ID,
pad_token_id=tokenizer.eos_token_id,
bad_words_ids=BAD_WORDS_IDS, # Guardrail 2: block <|endconversation|>
logits_processor=lp, # Guardrail 1: first-token filter
)
gen = out[0][inputs.shape[1]:]
text = tokenizer.decode(gen, skip_special_tokens=True).strip()
# Guardrails 3 & 4
if not is_valid_length(text, min_words=3, max_words=25):
continue
if is_verbatim_repetition(text, history_pairs, system_prompt):
continue
return text
raise RuntimeError("Failed to generate a valid user utterance after retries.")
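# Sketch of driving the simulator outside the UI (needs the GPU session that
# @spaces.GPU provides, so it is left commented out; the reply string is made up):
#   first = generate_reply(DEFAULT_SYSTEM_PROMPT, [])
#   nxt = generate_reply(DEFAULT_SYSTEM_PROMPT, [(first, "Use a 7-day time-based window.")])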
# ======================
# Gradio UI
# ======================
def respond(
your_reply: str,
history_pairs: List[Tuple[str, Optional[str]]],
system_prompt: str,
max_new_tokens: int,
temperature: float,
top_p: float,
):
# First turn: ignore your_reply and generate the initial UserLM utterance
if not history_pairs:
userlm = generate_reply(
system_prompt,
[],
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
)
history_pairs = [(userlm, None)]
return pairs_to_ui_messages(history_pairs), history_pairs, ""
# Subsequent turns require your reply
if not your_reply.strip():
gr.Info("Type your (assistant) reply on the right, then click Generate.")
return pairs_to_ui_messages(history_pairs), history_pairs, ""
# Close the last pair with your reply
last_userlm, _ = history_pairs[-1]
history_pairs[-1] = (last_userlm, your_reply.strip())
# Generate the next UserLM utterance
userlm = generate_reply(
system_prompt,
history_pairs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
)
history_pairs.append((userlm, None))
return pairs_to_ui_messages(history_pairs), history_pairs, ""
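# State sketch: history_pairs evolves turn by turn as
#   []                        --Generate-->  [(u1, None)]
#   [(u1, None)] + reply r1   --Generate-->  [(u1, r1), (u2, None)]
# where u_i are UserLM utterances and r_i are your (assistant) replies.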
def _clear():
return [], [], DEFAULT_SYSTEM_PROMPT, ""
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
f"""
# UserLM-8b: User Language Model Demo
**Model:** `{MODEL_ID}`
The AI plays the **user**, you play the **assistant**. Your messages appear on the **right**.
"""
)
system_box = gr.Textbox(
label="User Intent",
value=DEFAULT_SYSTEM_PROMPT,
lines=3,
placeholder="Enter the user's goal or intent",
)
# Use messages format so we can control left/right explicitly
chatbot = gr.Chatbot(
label="Conversation",
height=420,
type="messages", # modern format; tuples are deprecated
render_markdown=True,
autoscroll=True,
show_copy_button=True,
# You can set avatar images like: avatar_images=("assets/you.png", "assets/userlm.png")
)
# Your reply box (you play the assistant)
msg = gr.Textbox(
label="Your Reply (assistant)",
placeholder="Type your assistant response here…",
info="Leave blank & press _Generate_ to create the **first user message**.",
lines=2,
)
with gr.Accordion("Generation Settings", open=False):
max_new_tokens = gr.Slider(16, 512, value=128, step=16, label="max_new_tokens")
temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="temperature")
top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")
with gr.Row():
submit_btn = gr.Button("Generate", variant="primary")
clear_btn = gr.Button("Clear")
# Internal state keeps the compact (userLM, you) pairs used for decoding
history_pairs_state = gr.State([]) # List[Tuple[str, Optional[str]]]
with gr.Accordion("Implementation Details", open=False):
gr.Markdown(
"""
- Decoding defaults from [the model card](https://hf.co/microsoft/UserLM-8b): `temperature=1.0`, `top_p=0.8`, stop on `<|eot_id|>`, and block `<|endconversation|>`.
            - Guardrails from Appendix C.1 [of the paper](https://arxiv.org/abs/2510.06552): (1) first-token logit filter, (2) block `<|endconversation|>`, (3) 3–25 word length, (4) verbatim repetition filter.
"""
)
submit_btn.click(
        fn=respond,
inputs=[
msg,
history_pairs_state,
system_box,
max_new_tokens,
temperature,
top_p,
],
outputs=[chatbot, history_pairs_state, msg],
)
msg.submit(
        fn=respond,
inputs=[
msg,
history_pairs_state,
system_box,
max_new_tokens,
temperature,
top_p,
],
outputs=[chatbot, history_pairs_state, msg],
)
clear_btn.click(
fn=_clear,
outputs=[chatbot, history_pairs_state, system_box, msg],
)
if __name__ == "__main__":
demo.queue().launch()