"""Model interface for calling moderation models."""

import re

from openai import OpenAI
from openai_harmony import (
    DeveloperContent,
    HarmonyEncodingName,
    Message,
    Role,
    SystemContent,
    load_harmony_encoding,
)

from utils.constants import (
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    GENERIC_SYSTEM_PROMPT_PREFIX,
    LM_PROMPT_INSTRUCT,
    RESPONSE_FORMAT,
    ROUTER_URL,
    MODELS,
)
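
# Assumed shape of the MODELS registry in utils.constants, inferred from the
# lookups below ("id", "is_thinking"); the entry shown is illustrative only:
#   MODELS = [{"id": "...", "name": "...", "is_thinking": True}, ...]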


def get_model_info(model_id: str) -> dict | None:
    """Return the metadata dict for ``model_id``, or None if it is unknown."""
    for model in MODELS:
        if model["id"] == model_id:
            return model
    return None


def extract_model_id(choice: str) -> str:
    """Extract the model ID from a dropdown choice formatted as 'Name (id)'.

    Falls back to the raw string when no parentheses are present.
    """
    if not choice:
        return ""
    return choice.split("(")[-1].rstrip(")")
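
# Illustrative behavior (hypothetical dropdown label):
#   extract_model_id("GPT-OSS Safeguard 20B (openai/gpt-oss-safeguard-20b)")
#   -> "openai/gpt-oss-safeguard-20b"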


def is_gptoss_model(model_id: str) -> bool:
    """Return True if the model is a GPT-OSS variant."""
    return model_id.startswith("openai/gpt-oss")


def get_default_system_prompt(model_id: str, reasoning_effort: str = "Low") -> str:
    """Generate the default system prompt based on model type and reasoning effort."""
    if is_gptoss_model(model_id):
        # Render a Harmony system message, then recover the plain text between
        # <|message|> and <|end|> so it can be sent as an ordinary system message.
        enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        system_message = Message.from_role_and_content(
            Role.SYSTEM, SystemContent.new().with_reasoning_effort(reasoning_effort)
        )
        rendered = enc.decode(enc.render(system_message))
        match = re.search(r"<\|message\|>(.*?)<\|end\|>", rendered, re.DOTALL)
        if match is None:
            raise ValueError(f"Could not parse rendered Harmony system prompt: {rendered!r}")
        return match.group(1)
    else:
        # Qwen and other non-Harmony models: a plain system prompt string.
        return GENERIC_SYSTEM_PROMPT_PREFIX
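
# For reference, a rendered Harmony system message looks roughly like this
# (sketch, not verbatim library output); it is the text the
# <|message|>...<|end|> regex above slices:
#   <|start|>system<|message|>You are ChatGPT...\nReasoning: low\n...<|end|>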


def make_messages(
    test: str,
    policy: str,
    model_id: str,
    reasoning_effort: str = "Low",
    system_prompt: str | None = None,
    response_format: str = RESPONSE_FORMAT,
) -> list[dict]:
    """Build the chat messages for a model, switching format by model type."""
    if system_prompt is None:
        system_prompt = get_default_system_prompt(model_id, reasoning_effort)
    if model_id.startswith("openai/gpt-oss-safeguard"):
        # GPT-OSS Safeguard uses the Harmony encoding: render each message to
        # text, then split the role and content back out so the OpenAI-compatible
        # router can accept plain role/content dicts.
        enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        conv_messages = [
            Message.from_role_and_content(
                Role.DEVELOPER,
                DeveloperContent.new().with_instructions(policy + "\n\n" + response_format),
            ),
            Message.from_role_and_content(Role.USER, test),
        ]
        messages = [
            {"role": "system", "content": system_prompt},
        ]
        for pre_msg in conv_messages:
            prompt = enc.decode(enc.render(pre_msg))
            messages.append({
                "role": re.search(r"<\|start\|>(.*?)<\|message\|>", prompt).group(1),
                "content": re.search(r"<\|message\|>(.*?)<\|end\|>", prompt, re.DOTALL).group(1),
            })
        return messages
    else:
        # Other models (e.g. Qwen): fold the system prompt, policy, and response
        # format into a single system message.
        system_content = LM_PROMPT_INSTRUCT.format(
            system_prompt=system_prompt,
            policy=policy,
            response_format=response_format,
        )
        return [
            {"role": "system", "content": system_content},
            {"role": "user", "content": f"Content: {test}\n\nResponse:"},
        ]
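
# Shape of the list make_messages returns for a Harmony model (placeholder
# values, not real output):
#   [
#       {"role": "system", "content": "<rendered Harmony system prompt>"},
#       {"role": "developer", "content": "<policy>\n\n<response format>"},
#       {"role": "user", "content": "<content under test>"},
#   ]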


def run_test(
    model_id: str,
    test_input: str,
    policy: str,
    hf_token: str,
    reasoning_effort: str = "Low",
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    system_prompt: str | None = None,
    response_format: str = RESPONSE_FORMAT,
) -> dict:
    """Run test on model."""
    import gradio as gr
    
    model_info = get_model_info(model_id)
    if not model_info:
        raise gr.Error(f"Unknown model: {model_id}. Please select a valid model from the dropdown.")

    client = OpenAI(base_url=ROUTER_URL, api_key=hf_token)
    messages = make_messages(test_input, policy, model_id, reasoning_effort, system_prompt, response_format)

    try:
        completion = client.chat.completions.create(
            model=model_id,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=None,
            extra_headers={"X-HF-Bill-To": "roosttools"},
        )
    except Exception as e:
        error_msg = str(e)
        if "401" in error_msg or "authentication" in error_msg.lower():
            raise gr.Error(f"Authentication failed: {error_msg}. Please check your token permissions.")
        elif "400" in error_msg or "bad request" in error_msg.lower():
            raise gr.Error(f"Invalid request: {error_msg}. Please check your input and try again.")
        elif "429" in error_msg or "rate limit" in error_msg.lower():
            raise gr.Error(f"Rate limit exceeded: {error_msg}. Please wait a moment and try again.")
        elif "timeout" in error_msg.lower():
            raise gr.Error(f"Request timed out: {error_msg}. Please try again with a shorter input or lower max_tokens.")
        else:
            raise gr.Error(f"Model inference failed: {error_msg}. Please check your inputs and try again.")

    result = {"content": completion.choices[0].message.content}

    # Extract reasoning if available
    message = completion.choices[0].message
    if model_info["is_thinking"]:
        if is_gptoss_model(model_id):
            # GPT-OSS: check reasoning or reasoning_content field
            if hasattr(message, "reasoning") and message.reasoning:
                result["reasoning"] = message.reasoning
            elif hasattr(message, "reasoning_content") and message.reasoning_content:
                result["reasoning"] = message.reasoning_content
        else:
            # Qwen Thinking: extract from content using </think> tag
            content = message.content
            if "</think>" in content:
                result["reasoning"] = content.split("</think>")[0].strip()
                # Also update content to be the part after </think>
                result["content"] = content.split("</think>")[-1].strip()

    return result
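

# Minimal smoke-test sketch: assumes a valid Hugging Face token in the HF_TOKEN
# environment variable and that the model ID below is registered in MODELS.
if __name__ == "__main__":
    import os

    output = run_test(
        model_id="openai/gpt-oss-safeguard-20b",  # assumed MODELS entry
        test_input="Example content to classify.",
        policy="Flag content that contains personal insults.",
        hf_token=os.environ["HF_TOKEN"],
    )
    print(output.get("reasoning", "(no reasoning returned)"))
    print(output["content"])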