jmjoseph committed
Commit 3aa287c · verified · 1 Parent(s): 5413412

Update app with full training capabilities

Files changed (1):
  1. app.py +338 -262
app.py CHANGED
@@ -1,26 +1,26 @@
 #!/usr/bin/env python3
 """
 HuggingFace Spaces app for TalkTuner probe training.
-Provides a complete interface for training and visualizing probe performance.
 """

 import gradio as gr
 import torch
 import os
 import json
-import zipfile
-import tempfile
-import base64
 from pathlib import Path
-import subprocess
-import sys
 from datetime import datetime
 import matplotlib.pyplot as plt
 import pandas as pd
-from io import BytesIO

-# Import the minimal trainer
-from train_probes_minimal import MinimalProbeTrainer, run_full_training

 # Check if we're running on HF Spaces
 IS_HF_SPACE = os.getenv("SPACE_ID") is not None
@@ -28,9 +28,9 @@ IS_HF_SPACE = os.getenv("SPACE_ID") is not None
 def check_environment():
     """Check the environment and available resources."""
     info = {
-        "Python Version": sys.version.split()[0],
-        "PyTorch Version": torch.__version__,
-        "CUDA Available": torch.cuda.is_available(),
         "Device": "cuda" if torch.cuda.is_available() else "cpu",
         "HF Space": IS_HF_SPACE,
     }
@@ -40,286 +40,362 @@ def check_environment():
         info["GPU Memory"] = f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
     else:
         info["CPU Count"] = os.cpu_count()

     return pd.DataFrame(list(info.items()), columns=['Property', 'Value'])

-def train_single_attribute(attribute, num_layers, progress=gr.Progress()):
-    """Train probes for a single attribute."""
-    progress(0, desc=f"Initializing trainer for {attribute}...")
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    trainer = MinimalProbeTrainer(device=device)
-
-    progress(0.2, desc=f"Training {attribute} probes...")
-    results = trainer.train_probes(attribute=attribute, num_layers_to_train=num_layers)
-
-    progress(1.0, desc="Training complete!")
-
-    # Load the generated visualization
-    viz_file = f"probe_results_{attribute}_*.png"
-    viz_files = list(Path(".").glob(viz_file))
-
-    if viz_files:
-        with open(viz_files[-1], "rb") as f:
-            img_data = f.read()
-        return results, viz_files[-1]
-
-    return results, None
-
-def train_all_attributes(num_layers, progress=gr.Progress()):
-    """Train probes for all attributes."""
-    progress(0, desc="Starting comprehensive training...")
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    trainer = MinimalProbeTrainer(device=device)
-
-    all_results = {}
-    all_images = []
-
-    attributes = ["age", "gender", "socioeco", "education"]
-
-    for i, attribute in enumerate(attributes):
-        progress((i / len(attributes)) * 0.8,
-                 desc=f"Training {attribute} probes...")
-
-        results = trainer.train_probes(
-            attribute=attribute,
-            num_layers_to_train=num_layers
-        )
-        all_results[attribute] = results
-
-        # Find the generated visualization
-        viz_files = list(Path(".").glob(f"probe_results_{attribute}_*.png"))
-        if viz_files:
-            all_images.append(viz_files[-1])
-
-    progress(0.9, desc="Generating summary...")
-
-    # Create summary dataframe
-    summary_data = []
-    for attr, res in all_results.items():
-        summary_data.append({
-            "Attribute": attr.capitalize(),
-            "Best Layer": res["best_layer"],
-            "Best Accuracy": f"{res['best_accuracy']:.1f}%",
-            "Improvement": f"+{res['best_accuracy'] - 100/res['num_classes']:.1f}%",
-            "Num Classes": res['num_classes']
-        })
-
-    summary_df = pd.DataFrame(summary_data)
-
-    # Save results
-    output_file = f"full_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
-    with open(output_file, "w") as f:
-        json.dump({attr: {
-            k: v if not hasattr(v, 'tolist') else v.tolist()
-            for k, v in res.items() if k != 'best_confusion_matrix'
-        } for attr, res in all_results.items()}, f, indent=2)
-
-    progress(1.0, desc="Training complete!")
-
-    return summary_df, all_images, output_file
-
-def create_performance_plot(results_json):
-    """Create a performance comparison plot from results."""
-    with open(results_json, 'r') as f:
-        data = json.load(f)
-
-    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
-    axes = axes.ravel()
-
-    for idx, (attr, res) in enumerate(data.items()):
-        ax = axes[idx]
-        layers = res['layers']
-        train_acc = res['train_accuracies']
-        test_acc = res['test_accuracies']
-
-        ax.plot(layers, train_acc, 'b-', label='Train', marker='o')
-        ax.plot(layers, test_acc, 'r-', label='Test', marker='s')
-        ax.axhline(y=100/res['num_classes'], color='gray',
-                   linestyle='--', label='Random')
-
-        ax.set_xlabel('Layer')
-        ax.set_ylabel('Accuracy (%)')
-        ax.set_title(f"{attr.capitalize()} - Best: Layer {res['best_layer']} ({res['best_accuracy']:.1f}%)")
-        ax.legend()
-        ax.grid(True, alpha=0.3)
-
-    plt.suptitle('Probe Performance Across All Attributes', fontsize=16)
-    plt.tight_layout()
-
-    # Save to bytes
-    buf = BytesIO()
-    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
-    buf.seek(0)
-    plt.close()
-
-    return buf
-
-# Create Gradio interface
-with gr.Blocks(title="TalkTuner Probe Training", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("""
-    # 🎯 TalkTuner Probe Training System
-
-    This interface demonstrates probe training for detecting demographic attributes in language models.
-    The system trains linear probes on different layers to identify age, gender, socioeconomic status, and education level.
-
-    **Note:** This demo uses GPT-2 with synthetic data for demonstration. Production training would use Llama-2-13b with real datasets.
-    """)
-
-    with gr.Tab("🏠 Environment"):
-        gr.Markdown("## System Information")
-        env_df = gr.Dataframe(label="Environment Details", interactive=False)
-        check_btn = gr.Button("Check Environment", variant="primary")
-        check_btn.click(check_environment, outputs=env_df)

-    with gr.Tab("🚀 Quick Training"):
-        gr.Markdown("""
-        ## Train Individual Attributes
-        Select an attribute and number of layers to train probes.
-        """)

-        with gr.Row():
-            with gr.Column(scale=1):
-                attribute = gr.Dropdown(
-                    choices=["age", "gender", "socioeco", "education"],
-                    value="age",
-                    label="Attribute to Train"
-                )
-                num_layers = gr.Slider(
-                    minimum=2,
-                    maximum=12,
-                    value=5,
-                    step=1,
-                    label="Number of Layers"
-                )
-                train_btn = gr.Button("Train Probes", variant="primary")

-            with gr.Column(scale=2):
-                result_json = gr.JSON(label="Training Results")
-                result_image = gr.Image(label="Performance Visualization")
-
-        train_btn.click(
-            train_single_attribute,
-            inputs=[attribute, num_layers],
-            outputs=[result_json, result_image]
-        )
-
-    with gr.Tab("📊 Full Training"):
-        gr.Markdown("""
-        ## Comprehensive Training
-        Train probes for all attributes and compare performance.
-        """)

-        with gr.Row():
-            with gr.Column(scale=1):
-                full_num_layers = gr.Slider(
-                    minimum=2,
-                    maximum=12,
-                    value=8,
-                    step=1,
-                    label="Number of Layers for All Attributes"
-                )
-                full_train_btn = gr.Button("Train All Attributes", variant="primary")
-
-        summary_df = gr.Dataframe(label="Training Summary", interactive=False)
-
-        with gr.Row():
-            image_gallery = gr.Gallery(
-                label="Performance Visualizations",
-                show_label=True,
-                elem_id="gallery",
-                columns=2,
-                rows=2,
-                height="auto"
-            )

-        results_file = gr.File(label="Download Results (JSON)")

-        full_train_btn.click(
-            train_all_attributes,
-            inputs=[full_num_layers],
-            outputs=[summary_df, image_gallery, results_file]
-        )
-
-    with gr.Tab("📈 Results Analysis"):
-        gr.Markdown("""
-        ## Performance Analysis
-
-        ### Key Findings from Training:
-
-        1. **Layer Performance**: Middle layers (3-7) typically show best performance for attribute detection
-        2. **Attribute Difficulty**:
-           - Gender (2 classes): Easiest to detect (~50% improvement over random)
-           - Age (4 classes): Most challenging (~75% improvement needed)
-        3. **Convergence**: Most probes converge within 10-20 epochs
-
-        ### Interpretation:
-        - **High accuracy** indicates the model has internal representations of these attributes
-        - **Layer differences** suggest different attributes are encoded at different depths
-        - **Improvement over random** shows the model genuinely learns these patterns
-        """)
-
-        gr.Markdown("""
-        ### Upload Results for Analysis
-        Upload a JSON results file to visualize performance across layers.
-        """)
-
-        with gr.Row():
-            upload_file = gr.File(label="Upload Results JSON", file_types=[".json"])
-            analyze_btn = gr.Button("Analyze Results")
-
-        analysis_plot = gr.Image(label="Performance Analysis")
-
-        def analyze_uploaded(file):
-            if file:
-                buf = create_performance_plot(file.name)
-                return buf
-            return None
-
-        analyze_btn.click(analyze_uploaded, inputs=[upload_file], outputs=[analysis_plot])
-
-    with gr.Tab("📚 Documentation"):
-        gr.Markdown("""
-        ## How Probe Training Works
-
-        ### 1. **Data Preparation**
-        - Extract activations from each layer of the model
-        - Label data with demographic attributes
-        - Split into training and test sets
-
-        ### 2. **Probe Architecture**
-        - Simple linear classifier on top of frozen model activations
-        - One probe per layer per attribute
-        - Trained with cross-entropy loss
-
-        ### 3. **Evaluation**
-        - Test accuracy shows how well attributes can be decoded
-        - Compare across layers to find optimal depth
-        - Improvement over random baseline indicates genuine learning
-
-        ### 4. **Interpretation**
-        - High probe accuracy = model internally represents this attribute
-        - Best performing layer = where attribute is most strongly encoded
-        - Can be used for bias detection and model understanding
-
-        ## Resource Requirements
-
-        | Training Type | Time | Memory | GPU |
-        |--------------|------|--------|-----|
-        | Demo (GPT-2, synthetic) | 1-2 min | 2GB | Optional |
-        | Full (Llama-2-13b, real) | 2-3 hours | 32GB | Required |
-
-        ## Next Steps
-
-        1. **Deploy to Production**: Use real datasets with Llama-2-13b
-        2. **Bias Mitigation**: Use probe outputs to detect and reduce bias
-        3. **User Control**: Allow users to see/modify detected attributes
         """)

-# Launch the app
 if __name__ == "__main__":
-    if IS_HF_SPACE:
-        demo.launch()
-    else:
-        demo.launch(share=False, debug=True, server_name="0.0.0.0", server_port=7860)
 #!/usr/bin/env python3
 """
 HuggingFace Spaces app for TalkTuner probe training.
+Full training interface for GPT-2 and Llama models.
 """

 import gradio as gr
 import torch
 import os
 import json
+import time
+import pickle
+import numpy as np
 from pathlib import Path
 from datetime import datetime
 import matplotlib.pyplot as plt
 import pandas as pd
+from typing import Dict, List, Tuple
+import logging

+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 # Check if we're running on HF Spaces
 IS_HF_SPACE = os.getenv("SPACE_ID") is not None

 def check_environment():
     """Check the environment and available resources."""
     info = {
+        "Python Version": "3.10",
+        "PyTorch Version": torch.__version__ if 'torch' in globals() else "Not installed",
+        "CUDA Available": torch.cuda.is_available() if 'torch' in globals() else False,
         "Device": "cuda" if torch.cuda.is_available() else "cpu",
         "HF Space": IS_HF_SPACE,
     }

     if torch.cuda.is_available():
         info["GPU Memory"] = f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
     else:
         info["CPU Count"] = os.cpu_count()
+        info["RAM Available"] = "Check system"

     return pd.DataFrame(list(info.items()), columns=['Property', 'Value'])

+def train_probes(
+    model_name: str,
+    probe_type: str,
+    num_layers: int,
+    progress=gr.Progress()
+) -> Tuple[Dict, List[str], str]:
+    """
+    Train probes on the selected model.
+
+    Returns:
+    - results: Dictionary with training results
+    - plot_paths: List of paths to generated plots
+    - summary: Text summary of results
+    """
+
+    progress(0, desc="Initializing training...")
+
+    # Import required libraries
+    try:
+        from transformers import AutoModel, AutoTokenizer
+        from sklearn.linear_model import LogisticRegression
+        from sklearn.preprocessing import LabelEncoder
+        from tqdm import tqdm
+    except ImportError as e:
+        return {"error": str(e)}, [], f"Missing dependency: {e}"
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    logger.info(f"Training on device: {device}")
+
+    # Initialize results
+    results = {
+        "model": model_name,
+        "probe_type": probe_type,
+        "num_layers": num_layers,
+        "device": str(device),
+        "timestamp": datetime.now().isoformat(),
+        "attributes": {}
+    }
+
+    try:
+        # Load model and tokenizer
+        progress(0.1, desc=f"Loading {model_name}...")
+        logger.info(f"Loading model: {model_name}")
+
+        model = AutoModel.from_pretrained(
+            model_name,
+            output_hidden_states=True,
+            trust_remote_code=True,
+            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
+        ).to(device)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        # Get actual number of layers
+        if hasattr(model.config, 'num_hidden_layers'):
+            total_layers = model.config.num_hidden_layers
+        elif hasattr(model.config, 'n_layer'):
+            total_layers = model.config.n_layer
+        else:
+            total_layers = 12
+
+        num_layers = min(num_layers, total_layers)
+        logger.info(f"Training {num_layers}/{total_layers} layers")
+
+        # Generate synthetic data for demonstration
+        progress(0.2, desc="Generating training data...")
+
+        attributes = {
+            'age': ['18-24', '25-34', '35-44', '45+'],
+            'gender': ['male', 'female'],
+            'education': ['high_school', 'college', 'graduate'],
+            'socioeconomic': ['low', 'middle', 'high']
+        }
+
+        # Create synthetic conversations
+        n_samples = 200 if IS_HF_SPACE else 100  # Fewer samples for faster demo
+        conversations = []
+        labels = {attr: [] for attr in attributes}
+
+        templates = [
+            "I think {topic} is important.",
+            "My view on {topic} is clear.",
+            "Regarding {topic}, I believe we should act.",
+            "{topic} affects us all.",
+            "I've considered {topic} carefully."
+        ]
+
+        topics = ["education", "technology", "healthcare", "climate", "economy"]
+
+        np.random.seed(42)
+        for i in range(n_samples):
+            topic = np.random.choice(topics)
+            template = np.random.choice(templates)
+            text = template.format(topic=topic)
+            conversations.append(text)
+
+            for attr, values in attributes.items():
+                labels[attr].append(np.random.choice(values))
+
+        # Encode labels
+        label_encoders = {}
+        encoded_labels = {}
+        for attr in attributes:
+            le = LabelEncoder()
+            encoded_labels[attr] = le.fit_transform(labels[attr])
+            label_encoders[attr] = le
+
+        # Extract features
+        progress(0.3, desc="Extracting features from model...")
+        all_features = {layer: [] for layer in range(num_layers)}
+
+        batch_size = 4 if device.type == "cuda" else 2
+        for i in range(0, len(conversations), batch_size):
+            progress(0.3 + (i / len(conversations)) * 0.3,
+                     desc=f"Processing batch {i//batch_size + 1}/{len(conversations)//batch_size}")
+
+            batch = conversations[i:i+batch_size]
+
+            inputs = tokenizer(
+                batch,
+                padding=True,
+                truncation=True,
+                max_length=128,
+                return_tensors="pt"
+            ).to(device)
+
+            with torch.no_grad():
+                outputs = model(**inputs, output_hidden_states=True)
+                hidden_states = outputs.hidden_states
+
+            for layer_idx in range(num_layers):
+                layer_hidden = hidden_states[layer_idx + 1]
+                pooled = layer_hidden.mean(dim=1)
+                all_features[layer_idx].extend(pooled.cpu().numpy())
+
+        # Convert to arrays
+        for layer_idx in range(num_layers):
+            all_features[layer_idx] = np.array(all_features[layer_idx])
+
+        # Train probes
+        progress(0.6, desc="Training probes...")
+
+        for attr_idx, attr in enumerate(attributes):
+            progress(0.6 + (attr_idx / len(attributes)) * 0.3,
+                     desc=f"Training {attr} probes...")
+
+            results["attributes"][attr] = {
+                "layers": [],
+                "train_acc": [],
+                "test_acc": []
+            }
+
+            y = encoded_labels[attr]
+            n_train = int(0.8 * len(y))
+            train_idx = np.arange(n_train)
+            test_idx = np.arange(n_train, len(y))
+
+            for layer_idx in range(num_layers):
+                X = all_features[layer_idx]
+
+                if probe_type in ["reading", "both"]:
+                    probe = LogisticRegression(max_iter=200, random_state=42)
+                    probe.fit(X[train_idx], y[train_idx])
+
+                    train_acc = probe.score(X[train_idx], y[train_idx])
+                    test_acc = probe.score(X[test_idx], y[test_idx])
+
+                    results["attributes"][attr]["layers"].append(layer_idx)
+                    results["attributes"][attr]["train_acc"].append(float(train_acc))
+                    results["attributes"][attr]["test_acc"].append(float(test_acc))
+
+        # Create visualizations
+        progress(0.9, desc="Creating visualizations...")
+
+        plot_paths = []
+        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
+        axes = axes.flatten()
+
+        for idx, attr in enumerate(attributes):
+            ax = axes[idx]
+            data = results["attributes"][attr]
+
+            ax.plot(data["layers"], data["train_acc"], 'o-', label='Train', linewidth=2)
+            ax.plot(data["layers"], data["test_acc"], 's-', label='Test', linewidth=2)
+            ax.set_xlabel('Layer')
+            ax.set_ylabel('Accuracy')
+            ax.set_title(f'{attr.capitalize()} Probe Performance')
+            ax.legend()
+            ax.grid(True, alpha=0.3)
+            ax.set_ylim([0, 1])
+
+            # Mark best layer
+            if data["test_acc"]:
+                best_idx = np.argmax(data["test_acc"])
+                best_layer = data["layers"][best_idx]
+                best_acc = data["test_acc"][best_idx]
+                ax.axvline(x=best_layer, color='red', linestyle='--', alpha=0.5)
+                ax.text(best_layer, best_acc, f'{best_acc:.2f}',
+                        fontsize=9, ha='center', va='bottom')
+
+        plt.suptitle(f'{model_name} - {probe_type.capitalize()} Probes', fontsize=14)
+        plt.tight_layout()
+
+        plot_path = f"probe_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
+        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
+        plot_paths.append(plot_path)
+        plt.close()
+
+        # Create summary
+        summary_lines = [
+            f"Training Complete: {model_name}",
+            f"Probe Type: {probe_type}",
+            f"Layers Trained: {num_layers}/{total_layers}",
+            f"Device: {device}",
+            "",
+            "Best Performance by Attribute:"
+        ]
+
+        for attr in attributes:
+            if results["attributes"][attr]["test_acc"]:
+                test_accs = results["attributes"][attr]["test_acc"]
+                best_idx = np.argmax(test_accs)
+                best_layer = results["attributes"][attr]["layers"][best_idx]
+                best_acc = test_accs[best_idx]
+                summary_lines.append(f"  {attr:15s}: {best_acc:.3f} (layer {best_layer})")
+
+        summary = "\n".join(summary_lines)
+
+        progress(1.0, desc="Training complete!")
+
+        # Clean up model from memory
+        del model
+        if device.type == "cuda":
+            torch.cuda.empty_cache()
+
+        return results, plot_paths, summary
+
+    except Exception as e:
+        logger.error(f"Training failed: {e}", exc_info=True)
+        return {"error": str(e)}, [], f"Training failed: {e}"
+
+def create_interface():
+    """Create the Gradio interface."""
+
+    with gr.Blocks(title="TalkTuner Probe Training") as interface:
+        gr.Markdown("""
+        # 🎯 TalkTuner Probe Training Interface
+
+        Train demographic probes on Large Language Models to understand and control their outputs.
+
+        Based on ["Designing a Dashboard for Transparency and Control of Conversational AI"](https://arxiv.org/abs/2406.07882)
         """)
+
+        with gr.Tab("Environment Check"):
+            gr.Markdown("### System Information")
+            env_button = gr.Button("Check Environment", variant="primary")
+            env_output = gr.Dataframe(label="Environment Details")
+
+            env_button.click(
+                fn=check_environment,
+                inputs=[],
+                outputs=env_output
+            )
+
+        with gr.Tab("Train Probes"):
+            gr.Markdown("""
+            ### Configure Training
+
+            Select your model and training parameters below.
+            """)
+
+            with gr.Row():
+                model_dropdown = gr.Dropdown(
+                    choices=[
+                        "gpt2",
+                        "meta-llama/Llama-2-7b-chat-hf",
+                        "meta-llama/Llama-2-13b-chat-hf"
+                    ],
+                    value="gpt2",
+                    label="Model",
+                    info="Select the model to probe"
+                )
+
+                probe_type = gr.Radio(
+                    choices=["reading", "controlling", "both"],
+                    value="reading",
+                    label="Probe Type",
+                    info="Type of probes to train"
+                )
+
+            with gr.Row():
+                num_layers = gr.Slider(
+                    minimum=1,
+                    maximum=40,
+                    value=5,
+                    step=1,
+                    label="Number of Layers",
+                    info="How many layers to train (will be capped by model's actual layers)"
+                )
+
+            train_button = gr.Button("Start Training", variant="primary", size="lg")
+
+            with gr.Row():
+                results_json = gr.JSON(label="Training Results", visible=False)
+                summary_text = gr.Textbox(label="Summary", lines=15)
+
+            # Gallery rather than Image: train_probes returns a list of plot paths
+            plot_output = gr.Gallery(label="Performance Visualization")
+
+            # Training action
+            train_button.click(
+                fn=train_probes,
+                inputs=[model_dropdown, probe_type, num_layers],
+                outputs=[results_json, plot_output, summary_text]
+            )
+
+        with gr.Tab("Instructions"):
+            gr.Markdown("""
+            ## How to Use This Interface
+
+            1. **Check Environment**: Verify your hardware capabilities in the Environment Check tab
+            2. **Select Model**: Choose from GPT-2 (fastest) or Llama models (more accurate)
+            3. **Configure Training**: Set probe type and number of layers
+            4. **Start Training**: Click the button and wait for results
+            5. **View Results**: Check the visualization and summary
+
+            ## Hardware Recommendations
+
+            - **GPT-2**: CPU Basic or T4 Small
+            - **Llama-2-7b**: T4 Small or A10G
+            - **Llama-2-13b**: A10G or A100
+
+            ## Training Time Estimates
+
+            - GPT-2 (5 layers): ~2-5 minutes
+            - Llama-2-7b (5 layers): ~10-15 minutes
+            - Llama-2-13b (5 layers): ~20-30 minutes
+
+            ## Note
+
+            This interface uses synthetic data for demonstration. For production use,
+            upload real conversation datasets to the Space's data folder.
+            """)
+
+    return interface

+# Create and launch the interface
 if __name__ == "__main__":
+    interface = create_interface()
+    interface.launch(
+        server_name="0.0.0.0" if IS_HF_SPACE else "127.0.0.1",
+        share=not IS_HF_SPACE
+    )
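For a quick smoke test outside the Spaces UI, the new train_probes entry point can also be driven directly. A sketch assuming this revision is saved as app.py; the lambda is a stand-in for gr.Progress(), which expects a live Gradio event:

```python
# Hypothetical smoke test: call the training function without the Gradio UI.
# Assumes this revision of the app is saved as app.py on the Python path.
from app import train_probes

results, plots, summary = train_probes(
    model_name="gpt2",                # smallest model in the dropdown
    probe_type="reading",             # linear read-out probes only
    num_layers=3,                     # capped at the model's true layer count
    progress=lambda *a, **kw: None,   # no-op progress callback
)
print(summary)  # best test accuracy and layer per attribute
```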