"""Model interface for calling moderation models."""
import json
import re
from openai import OpenAI
from openai_harmony import (
DeveloperContent,
HarmonyEncodingName,
Message,
Role,
SystemContent,
load_harmony_encoding,
)
from utils.constants import (
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
GENERIC_SYSTEM_PROMPT_PREFIX,
LM_PROMPT_INSTRUCT,
RESPONSE_FORMAT,
ROUTER_URL,
MODELS,
)


def get_model_info(model_id: str) -> dict | None:
    """Return model metadata for the given ID, or None if the ID is unknown."""
    for model in MODELS:
        if model["id"] == model_id:
            return model
    return None


def extract_model_id(choice: str) -> str:
"""Extract model ID from dropdown choice format 'Name (id)'."""
if not choice:
return ""
return choice.split("(")[-1].rstrip(")")


def is_gptoss_model(model_id: str) -> bool:
    """Check if model is GPT-OSS."""
    return model_id.startswith("openai/gpt-oss")


def get_default_system_prompt(model_id: str, reasoning_effort: str = "Low") -> str:
    """Generate default system prompt based on model type and policy."""
    if is_gptoss_model(model_id):
        # Render a Harmony system message, then pull the plain-text body out of
        # its "<|message|>...<|end|>" envelope.
        enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        system_prompt_harmony = Message.from_role_and_content(
            Role.SYSTEM, SystemContent.new().with_reasoning_effort(reasoning_effort)
        )
        rendered = enc.decode(enc.render(system_prompt_harmony))
        match = re.search(r"<\|message\|>(.*?)<\|end\|>", rendered, re.DOTALL)
        return match.group(1) if match else rendered
else:
# Qwen: formatted system prompt (goes in system role)
return GENERIC_SYSTEM_PROMPT_PREFIX
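
# Note: for GPT-OSS models the default system prompt is the body of the
# rendered Harmony system message (which carries the reasoning-effort
# setting); for all other models it is GENERIC_SYSTEM_PROMPT_PREFIX verbatim.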


def make_messages(
    test: str,
    policy: str,
    model_id: str,
    reasoning_effort: str = "Low",
    system_prompt: str | None = None,
    response_format: str = RESPONSE_FORMAT,
) -> list[dict]:
    """Create chat messages appropriate for the model type."""
    if system_prompt is None:
        # Fall back to the model's default system prompt when none is given.
        system_prompt = get_default_system_prompt(model_id, reasoning_effort)
    if model_id.startswith("openai/gpt-oss-safeguard"):
        # GPT-OSS safeguard models use the Harmony encoding: the policy goes in
        # a developer message and the content under test in a user message.
        enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        conv_messages = [
            Message.from_role_and_content(
                Role.DEVELOPER,
                DeveloperContent.new().with_instructions(policy + "\n\n" + response_format),
            ),
            Message.from_role_and_content(Role.USER, test),
        ]
        messages = [
            {"role": "system", "content": system_prompt},
        ]
        # Render each Harmony message to text, then recover the role and body
        # from the "<|start|>role<|message|>body<|end|>" envelope.
        for pre_msg in conv_messages:
            prompt = enc.decode(enc.render(pre_msg))
            messages.append({
                "role": re.search(r"<\|start\|>(.*?)<\|message\|>", prompt).group(1),
                "content": re.search(r"<\|message\|>(.*?)<\|end\|>", prompt, re.DOTALL).group(1),
            })
        return messages
else:
system_content = LM_PROMPT_INSTRUCT.format(
system_prompt=system_prompt,
policy=policy,
response_format=response_format
)
return [
{"role": "system", "content": system_content},
{"role": "user", "content": f"Content: {test}\n\nResponse:"},
]
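
# Illustrative shape of the non-GPT-OSS output (model ID hypothetical):
#   make_messages("some text", "No spam.", "Qwen/Qwen3-8B")
#   -> [{"role": "system", "content": <LM_PROMPT_INSTRUCT filled in>},
#       {"role": "user", "content": "Content: some text\n\nResponse:"}]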


def run_test(
    model_id: str,
    test_input: str,
    policy: str,
    hf_token: str,
    reasoning_effort: str = "Low",
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    system_prompt: str | None = None,
    response_format: str = RESPONSE_FORMAT,
) -> dict:
    """Run a single moderation test against the selected model."""
    # Local import so the module stays usable without Gradio installed.
    import gradio as gr

    model_info = get_model_info(model_id)
if not model_info:
raise gr.Error(f"Unknown model: {model_id}. Please select a valid model from the dropdown.")
client = OpenAI(base_url=ROUTER_URL, api_key=hf_token)
messages = make_messages(test_input, policy, model_id, reasoning_effort, system_prompt, response_format)
try:
completion = client.chat.completions.create(
model=model_id,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=None,
extra_headers={"X-HF-Bill-To": "roosttools"},
)
except Exception as e:
error_msg = str(e)
if "401" in error_msg or "authentication" in error_msg.lower():
raise gr.Error(f"Authentication failed: {error_msg}. Please check your token permissions.")
elif "400" in error_msg or "bad request" in error_msg.lower():
raise gr.Error(f"Invalid request: {error_msg}. Please check your input and try again.")
elif "429" in error_msg or "rate limit" in error_msg.lower():
raise gr.Error(f"Rate limit exceeded: {error_msg}. Please wait a moment and try again.")
elif "timeout" in error_msg.lower():
raise gr.Error(f"Request timed out: {error_msg}. Please try again with a shorter input or lower max_tokens.")
else:
raise gr.Error(f"Model inference failed: {error_msg}. Please check your inputs and try again.")
    message = completion.choices[0].message
    result = {"content": message.content}
    # Extract reasoning if available
if model_info["is_thinking"]:
if is_gptoss_model(model_id):
# GPT-OSS: check reasoning or reasoning_content field
if hasattr(message, "reasoning") and message.reasoning:
result["reasoning"] = message.reasoning
elif hasattr(message, "reasoning_content") and message.reasoning_content:
result["reasoning"] = message.reasoning_content
else:
# Qwen Thinking: extract from content using </think> tag
content = message.content
if "</think>" in content:
result["reasoning"] = content.split("</think>")[0].strip()
# Also update content to be the part after </think>
result["content"] = content.split("</think>")[-1].strip()
return result
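

if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming the model ID below appears in
    # MODELS and that HF_TOKEN holds a token with inference permissions
    # (both are illustrative assumptions, not part of this module).
    import os

    output = run_test(
        model_id="openai/gpt-oss-safeguard-20b",  # hypothetical choice
        test_input="Buy cheap watches now!!!",
        policy="Flag content that is unsolicited commercial advertising.",
        hf_token=os.environ["HF_TOKEN"],
    )
    if "reasoning" in output:
        print("Reasoning:", output["reasoning"])
    print("Verdict:", output["content"])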