"""General helper utilities for authentication, policy loading, and file operations."""

import os

import gradio as gr
from dotenv import load_dotenv


def _load_token_from_env(env_var: str) -> str | None:
    """Load token from .env file."""
    load_dotenv()
    return os.getenv(env_var)


def get_label_emoji(policy_violation: int) -> str:
    """Get emoji for policy violation label."""
    return "❌" if policy_violation == 1 else "✅" if policy_violation == 0 else "⚠️"


def format_dataset_help_text(has_personal: bool, has_org: bool) -> str:
    """Format help text explaining dataset availability."""
    return (
        f"*Private Dataset: {'✅ Available' if has_personal else '❌ Requires personal token (OAuth login or .env)'}*\n"
        f"*ROOST Dataset: {'✅ Available' if has_org else '⚠️ Can load if public, requires org token to save'}*"
    )


def get_personal_token(oauth_token: gr.OAuthToken | None) -> tuple[str | None, str]:
    """
    Get personal Hugging Face token from OAuth or .env fallback.
    
    Used for personal/user operations like saving to private datasets.

    Args:
        oauth_token: Gradio OAuth token from user login, or None

    Returns:
        Tuple of (hf_token, status_message)
        - hf_token: Token string if available, None otherwise
        - status_message: Warning message if using local .env, empty string otherwise
    """
    if oauth_token is None or (isinstance(oauth_token, str) and oauth_token == "Log in to Hugging Face"):
        # Try loading from .env file
        hf_token = _load_token_from_env("HF_TOKEN_MLSOC")
        if hf_token is None:
            return None, ""
        return hf_token, "\n⚠️ Using local .env file for token (not online)"
    else:
        # OAuthToken object
        token = oauth_token.token
        if not token or not token.strip():
            hf_token = _load_token_from_env("HF_TOKEN_MLSOC")
            if hf_token:
                return hf_token, "\n⚠️ Using local .env file for token (not online)"
            return None, ""
        return token, ""


def get_org_token() -> str | None:
    """
    Get organization token from Space secret or .env fallback.
    
    Used for ROOST org dataset operations and inference (preferred).
    
    Returns:
        Token string if available, None otherwise
    """
    # Check Space secret HACKATHON_INFERENCE_TOKEN
    org_token = os.getenv("HACKATHON_INFERENCE_TOKEN")
    if org_token:
        return org_token
    
    # Fall back to .env file
    return _load_token_from_env("ROOST_TOKEN_FALLBACK")


def get_inference_token(oauth_token: gr.OAuthToken | None) -> tuple[str | None, str]:
    """
    Get token for inference (org token preferred, falls back to personal).
    
    Returns:
        Tuple of (token, status_message)
    """
    # Try org token first
    org_token = get_org_token()
    if org_token:
        return org_token, ""
    
    # Fall back to personal token
    personal_token, status_msg = get_personal_token(oauth_token)
    return personal_token, status_msg


def check_token_availability(oauth_token: gr.OAuthToken | None) -> tuple[bool, bool]:
    """
    Check which tokens are available.
    
    Returns:
        Tuple of (has_personal: bool, has_org: bool)
    """
    has_personal = get_personal_token(oauth_token)[0] is not None
    has_org = get_org_token() is not None
    return has_personal, has_org


def format_token_status(oauth_token: gr.OAuthToken | None) -> str:
    """
    Format markdown showing token status and usage.
    
    Returns:
        Markdown string explaining which tokens are set and their uses
    """
    has_personal, has_org = check_token_availability(oauth_token)
    
    lines = [
        "You can log in to your Hugging Face account to save your work in a private dataset and use the app for inference after the end of the hackathon.",
        "### Token Status",
    ]
    
    # Personal token status
    if has_personal:
        # Determine the token source; `oauth_token` may be None or a placeholder
        # string, so only access `.token` on a real OAuthToken object.
        if isinstance(oauth_token, gr.OAuthToken) and oauth_token.token:
            source = "OAuth login"
        else:
            source = ".env file"
        lines.append(f"- **Personal Token**: ✅ Available ({source})")
        lines.append("  - Enables: Inference (fallback), Private dataset saves/loads")
    else:
        lines.append("- **Personal Token**: ❌ Not available")
        lines.append("  - Required for: Private dataset operations")
    
    # Org token status
    if has_org:
        # Report whether the org token comes from the Space secret or the .env fallback
        if os.getenv("HACKATHON_INFERENCE_TOKEN"):
            source = "Space secret"
        else:
            source = ".env file"
        lines.append(f"- **Org Token**: ✅ Available ({source})")
        lines.append("  - Enables: Inference (preferred), ROOST dataset saves/loads")
    else:
        lines.append("- **Org Token**: ❌ Not available")
        lines.append("  - Required for: ROOST dataset saves")
        lines.append("  - Note: ROOST dataset can be loaded if public")
    
    return "\n".join(lines)


def load_preset_policy(preset_name: str, base_dir: str) -> tuple[str, str]:
    """Load preset policy from markdown file."""
    preset_files = {
        "Hate Speech Policy": "hate_speech.md",
        "Violence Policy": "violence.md",
        "Toxicity Policy": "toxicity.md",
    }
    if preset_name in preset_files:
        policy_path = os.path.join(base_dir, "example_policies", preset_files[preset_name])
        try:
            with open(policy_path, "r", encoding="utf-8") as f:
                policy_text = f.read()
            return policy_text, policy_text
        except FileNotFoundError:
            raise gr.Error(f"Policy file '{preset_files[preset_name]}' not found at {policy_path}. Please check that the file exists.")
        except Exception as e:
            raise gr.Error(f"Failed to load policy file '{preset_files[preset_name]}': {e}") from e
    return "", ""


def load_policy_from_file(file_path: str) -> tuple[str, str]:
    """Load policy from uploaded file."""
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        return content, content
    except FileNotFoundError:
        raise gr.Error(f"File not found: {file_path}. Please try uploading the file again.")
    except Exception as e:
        raise gr.Error(f"Failed to read policy file: {e}. Please check the file format and try again.") from e