from __future__ import annotations

import os
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList

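# UserLM-8b simulates the *user* side of a conversation: the model generates the
# user turns, and the human driving this demo types the assistant replies.
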
MODEL_ID = os.getenv("MODEL_ID", "microsoft/UserLM-8b")
DEFAULT_SYSTEM_PROMPT = (
    "You are a user who wants to compute rolling 7-day averages over uneven time stamps. "
    "You are suspicious of resampling magic and will accuse the assistant of witchcraft if it's not explicit."
)

def load_model(model_id: str = MODEL_ID):
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    # Per the model card: stop generation on <|eot_id|> and block <|endconversation|>.
    eot = "<|eot_id|>"
    end_conv = "<|endconversation|>"
    eot_ids = tok.encode(eot, add_special_tokens=False)
    end_conv_ids = tok.encode(end_conv, add_special_tokens=False)
    eos_token_id = eot_ids[0] if len(eot_ids) > 0 else tok.eos_token_id
    bad_words_ids = [[tid] for tid in end_conv_ids] if len(end_conv_ids) > 0 else None

    # Token ids that are forbidden as the very first generated token
    # (guardrail 1, the first-token logit filter).
    prob_first_tokens = ["I", "You", "Here", "i", "you", "here"]
    first_token_filter_ids = []
    for w in prob_first_tokens:
        ids = tok.encode(w, add_special_tokens=False)
        if ids:
            first_token_filter_ids.append(ids[0])

    return tok, mdl, eos_token_id, bad_words_ids, first_token_filter_ids

tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS, FIRST_TOKEN_FILTER_IDS = load_model()
model.generation_config.eos_token_id = EOS_TOKEN_ID
model.generation_config.pad_token_id = tokenizer.eos_token_id
model.eval()

def is_valid_length(text: str, min_words: int = 3, max_words: int = 25) -> bool:
    """Guardrail 3: keep generated user utterances between 3 and 25 words."""
    wc = len(text.split())
    return min_words <= wc <= max_words

def is_verbatim_repetition(
    new_text: str, history_pairs: List[Tuple[str, Optional[str]]], system_prompt: str
) -> bool:
    """Guardrail 4: reject utterances that repeat the intent prompt or an earlier user turn verbatim."""
    t = new_text.strip().lower()
    if t == system_prompt.strip().lower():
        return True
    for model_user, _ in history_pairs:
        if model_user and t == model_user.strip().lower():
            return True
    return False

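# Illustrative behaviour of the two filters above (inputs are hypothetical):
#   is_valid_length("ok thanks")                                  -> False (2 words)
#   is_verbatim_repetition("Hi there", [("Hi there", None)], "")  -> True
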
class ForbidFirstToken(LogitsProcessor):
    """Set -inf on a token list for the *first* generated token only."""

    def __init__(self, forbid_ids: List[int], prompt_len: int):
        self.forbid = list(set(int(x) for x in forbid_ids))
        self.prompt_len = int(prompt_len)

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids grows by one token per decoding step, so it equals the prompt
        # length exactly when the first new token is being sampled.
        if input_ids.shape[1] == self.prompt_len and self.forbid:
            scores[:, self.forbid] = float("-inf")
        return scores

def build_hf_messages(
    system_prompt: str, history_pairs: List[Tuple[str, Optional[str]]]
) -> List[Dict[str, str]]:
    """
    Construct messages for tokenizer.apply_chat_template.
    history_pairs = list of (model_user, human_assistant)
    """
    msgs: List[Dict[str, str]] = []
    if system_prompt.strip():
        msgs.append({"role": "system", "content": system_prompt.strip()})
    for model_user, human_assistant in history_pairs:
        if model_user:
            msgs.append({"role": "user", "content": model_user})
        if human_assistant:
            msgs.append({"role": "assistant", "content": human_assistant})
    return msgs

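# Illustrative result shape (contents are hypothetical):
#   build_hf_messages("Plan a trip.", [("Any ideas for Lisbon?", "Sure, start with...")])
#   -> [{"role": "system", "content": "Plan a trip."},
#       {"role": "user", "content": "Any ideas for Lisbon?"},
#       {"role": "assistant", "content": "Sure, start with..."}]
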
def pairs_to_ui_messages(
    history_pairs: List[Tuple[str, Optional[str]]]
) -> List[Dict[str, str]]:
    """
    Convert (model_user, human_assistant) pairs to Gradio Chatbot(type='messages') UI messages.
    Visual convention:
    - LEFT (role='assistant'): UserLM's utterances (the simulator)
    - RIGHT (role='user'): your replies (you play the assistant)
    """
    ui: List[Dict[str, str]] = []
    for model_user, human_assistant in history_pairs:
        if model_user:
            ui.append({"role": "assistant", "content": model_user})
        if human_assistant:
            ui.append({"role": "user", "content": human_assistant})
    return ui

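# On Hugging Face Spaces (ZeroGPU), @spaces.GPU requests a GPU for the duration
# of each call to the decorated function.
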
@spaces.GPU
def generate_reply(
    system_prompt: str,
    history_pairs: List[Tuple[str, Optional[str]]],
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 0.8,
    max_retries: int = 10,
) -> str:
    """Implements the 4 guardrails from Appendix C.1 and passes an explicit attention_mask."""
    messages = build_hf_messages(system_prompt, history_pairs)
    inputs = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)

    # Build an explicit attention mask; apply_chat_template returns only input_ids.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    if pad_id is not None and (inputs == pad_id).any():
        attention_mask = (inputs != pad_id).long()
    else:
        attention_mask = torch.ones_like(inputs, dtype=torch.long)

    for _ in range(max_retries):
        # Guardrail 1: filter the forbidden openers on the first generated token.
        lp = LogitsProcessorList(
            [ForbidFirstToken(FIRST_TOKEN_FILTER_IDS, prompt_len=inputs.shape[1])]
        )

        with torch.no_grad():
            out = model.generate(
                input_ids=inputs,
                attention_mask=attention_mask,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                eos_token_id=EOS_TOKEN_ID,
                pad_token_id=tokenizer.eos_token_id,
                bad_words_ids=BAD_WORDS_IDS,  # Guardrail 2: never emit <|endconversation|>.
                logits_processor=lp,
            )

        gen = out[0][inputs.shape[1]:]
        text = tokenizer.decode(gen, skip_special_tokens=True).strip()

        # Guardrails 3 and 4: length bounds and no verbatim repetition; retry otherwise.
        if not is_valid_length(text, min_words=3, max_words=25):
            continue
        if is_verbatim_repetition(text, history_pairs, system_prompt):
            continue
        return text

    raise RuntimeError("Failed to generate a valid user utterance after retries.")

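# Illustrative programmatic use outside the Gradio UI (strings are hypothetical):
#   first_turn = generate_reply(DEFAULT_SYSTEM_PROMPT, [])
#   next_turn = generate_reply(DEFAULT_SYSTEM_PROMPT, [(first_turn, "Try df.rolling('7D').mean().")])
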
def respond(
    your_reply: str,
    history_pairs: List[Tuple[str, Optional[str]]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    # First click with an empty history: let UserLM open the conversation.
    if not history_pairs:
        userlm = generate_reply(
            system_prompt,
            [],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        history_pairs = [(userlm, None)]
        return pairs_to_ui_messages(history_pairs), history_pairs, ""

    if not your_reply.strip():
        gr.Info("Type your (assistant) reply on the right, then click Generate.")
        return pairs_to_ui_messages(history_pairs), history_pairs, ""

    # Attach your reply to the latest UserLM turn, then generate the next user turn.
    last_userlm, _ = history_pairs[-1]
    history_pairs[-1] = (last_userlm, your_reply.strip())

    userlm = generate_reply(
        system_prompt,
        history_pairs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    history_pairs.append((userlm, None))

    return pairs_to_ui_messages(history_pairs), history_pairs, ""


def _clear():
    return [], [], DEFAULT_SYSTEM_PROMPT, ""

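# Gradio UI: UserLM's turns render on the LEFT, your (assistant) replies on the
# RIGHT; history_pairs_state keeps the (model_user, human_assistant) pairs
# across turns.
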
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # UserLM-8b: User Language Model Demo
        **Model:** `{MODEL_ID}`

        The AI plays the **user**, you play the **assistant**. Your messages appear on the **right**.
        """
    )

    system_box = gr.Textbox(
        label="User Intent",
        value=DEFAULT_SYSTEM_PROMPT,
        lines=3,
        placeholder="Enter the user's goal or intent",
    )

    chatbot = gr.Chatbot(
        label="Conversation",
        height=420,
        type="messages",
        render_markdown=True,
        autoscroll=True,
        show_copy_button=True,
    )

    msg = gr.Textbox(
        label="Your Reply (assistant)",
        placeholder="Type your assistant response here…",
        info="Leave blank & press _Generate_ to create the **first user message**.",
        lines=2,
    )

    with gr.Accordion("Generation Settings", open=False):
        max_new_tokens = gr.Slider(16, 512, value=128, step=16, label="max_new_tokens")
        temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")

    with gr.Row():
        submit_btn = gr.Button("Generate", variant="primary")
        clear_btn = gr.Button("Clear")

    history_pairs_state = gr.State([])

    with gr.Accordion("Implementation Details", open=False):
        gr.Markdown(
            """
            - Decoding defaults from [the model card](https://hf.co/microsoft/UserLM-8b): `temperature=1.0`, `top_p=0.8`, stop on `<|eot_id|>`, and block `<|endconversation|>`.
            - Guardrails from Appendix C.1 [of the paper](https://arxiv.org/abs/2510.06552): (1) first-token logit filter, (2) block endconversation, (3) 3–25 word length, (4) verbatim repetition filter.
            """
        )

    def _submit(your_text, pairs, sys_prompt, mnt, temp, tp):
        ui_msgs, new_pairs, cleared_text = respond(
            your_text, pairs, sys_prompt, mnt, temp, tp
        )
        return ui_msgs, new_pairs, cleared_text

    submit_btn.click(
        fn=_submit,
        inputs=[
            msg,
            history_pairs_state,
            system_box,
            max_new_tokens,
            temperature,
            top_p,
        ],
        outputs=[chatbot, history_pairs_state, msg],
    )
    msg.submit(
        fn=_submit,
        inputs=[
            msg,
            history_pairs_state,
            system_box,
            max_new_tokens,
            temperature,
            top_p,
        ],
        outputs=[chatbot, history_pairs_state, msg],
    )

    clear_btn.click(
        fn=_clear,
        outputs=[chatbot, history_pairs_state, system_box, msg],
    )

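# queue() enables request queuing, which is recommended when generations can be
# long-running (e.g. on GPU-backed Spaces).
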
if __name__ == "__main__":
    demo.queue().launch()