Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| """ | |
| Train reading and controlling probes for LLM attribute detection. | |
| This script trains linear probes on different layers of a language model to detect | |
| demographic attributes (age, gender, socioeconomic status, education level). | |
| """ | |
| import os | |
| import sys | |
| import argparse | |
| import pickle | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, Subset | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from tqdm.auto import tqdm | |
| import sklearn.model_selection | |
| from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix | |
| import matplotlib.pyplot as plt | |
| # Import custom modules | |
| try: | |
| from src.dataset import TextDataset | |
| from src.probes import LinearProbeClassification | |
| from src.train_test_utils import train, test | |
| from src.losses import edl_mse_loss | |
| except ImportError as e: | |
| print(f"β ERROR: Failed to import required modules: {e}") | |
| print("Please ensure all required modules are in the correct location.") | |
| sys.exit(1) | |
class TrainerConfig:
    """Hyper-parameter bag handed to ``probe.configure_optimizers``.

    The class attributes below are the defaults; any keyword argument given
    to the constructor is set on the instance, overriding a default of the
    same name or adding a new attribute.
    """

    # Defaults shared by every instance unless overridden per instance.
    learning_rate = 1e-3
    betas = (0.9, 0.95)
    weight_decay = 0.1  # only applied on matmul weights

    def __init__(self, **kwargs):
        """Set every ``name=value`` keyword argument as an instance attribute."""
        for name, value in kwargs.items():
            setattr(self, name, value)
class ProbeTrainer:
    """Train per-layer linear probes that read or control user attributes.

    An instance loads one HuggingFace causal LM plus its tokenizer. For each
    configured attribute it then builds a ``TextDataset``, splits it 80/20 and
    trains one ``LinearProbeClassification`` per layer, writing checkpoints, a
    confusion-matrix image and a pickle of per-layer accuracies.
    """

    # Probe input width used when the loaded model's config does not expose
    # ``hidden_size``; 5120 is the value this class previously hard-coded.
    DEFAULT_INPUT_DIM = 5120

    def __init__(self, model_name: str = "meta-llama/Llama-2-13b-chat-hf",
                 device: str = "cuda", use_auth_token: bool = True):
        """
        Initialize the probe trainer and load the tokenizer and model.

        Args:
            model_name: HuggingFace model name
            device: Device to use for training; "cuda" also switches the
                model to half precision on the GPU
            use_auth_token: Forwarded to both ``from_pretrained`` calls
        """
        self.device = device
        self.model_name = model_name
        # Stored so _initialize_model can honour it; it used to be ignored,
        # which made the --no-auth command line flag a no-op.
        self.use_auth_token = use_auth_token

        # Configuration flags (forwarded to TextDataset, the probe and the
        # train/test helpers further down).
        self.new_prompt_format = True
        self.residual_stream = True
        self.uncertainty = False  # True selects edl_mse_loss instead of BCELoss
        self.logistic = True
        self.augmented = False
        self.remove_last_ai_response = True
        self.include_inst = True
        self.one_hot = True

        # Label mappings: attribute identifier -> {class name: class index}
        self.label_mappings = {
            "_age_": {
                "child": 0,
                "adolescent": 1,
                "adult": 2,
                "older adult": 3,
            },
            "_gender_": {
                "male": 0,
                "female": 1,
            },
            "_socioeco_": {
                "low": 0,
                "middle": 1,
                "high": 2,
            },
            "_education_": {
                "someschool": 0,
                "highschool": 1,
                "collegemore": 2,
            },
        }
        # Human-readable attribute names, used in progress output only.
        self.prompt_translator = {
            "_age_": "age",
            "_gender_": "gender",
            "_socioeco_": "socioeconomic status",
            "_education_": "education level",
        }
        self.openai_dataset = {
            "_age_": "data/dataset/openai_age_1/",
            "_gender_": "data/dataset/openai_gender_1/",
            "_education_": "data/dataset/openai_education_1/",
            "_socioeco_": "data/dataset/openai_socioeconomic_1/",
        }
        # Dataset configurations: (primary directory, attribute identifier).
        # Education is the only attribute whose primary directory is an
        # openai_* one.
        self.dataset_configs = [
            ("data/dataset/llama_age_1/", "_age_"),
            ("data/dataset/llama_gender_1/", "_gender_"),
            ("data/dataset/llama_socioeconomic_1/", "_socioeco_"),
            ("data/dataset/openai_education_1/", "_education_"),
        ]

        # Initialize model and tokenizer
        print(f"π Initializing ProbeTrainer with model: {model_name}")
        self._initialize_model()

    def _initialize_model(self):
        """Load the tokenizer and model and put the model in eval mode.

        Both ``from_pretrained`` calls receive ``self.use_auth_token``. When
        ``self.device`` is "cuda" the model is converted to half precision and
        moved to the GPU. Any failure is printed and terminates the process
        with exit status 1.
        """
        try:
            print("π₯ Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_auth_token=self.use_auth_token,
            )
            print("β Tokenizer loaded successfully")

            print("π₯ Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                use_auth_token=self.use_auth_token,
            )
            if self.device == "cuda":
                print("π§ Moving model to GPU and setting to half precision...")
                self.model.half().cuda()
            self.model.eval()
            print("β Model loaded and ready")
        except Exception as e:
            print(f"β ERROR: Failed to initialize model: {e}")
            sys.exit(1)

    def _get_additional_datasets(self, label_idf: str, directory: str) -> List[str]:
        """Return the extra dataset directories used alongside ``directory``.

        Every attribute except education starts with the ``_2/`` sibling of
        ``directory`` and the attribute's ``openai_*_1/`` directory; the
        attribute-specific extras are then appended.

        Args:
            label_idf: Attribute identifier such as "_age_"
            directory: Primary dataset directory for that attribute

        Returns:
            A new list of directory paths.

        Raises:
            KeyError: If ``label_idf`` is neither "_education_" nor a key of
                ``self.openai_dataset``.
        """
        extras_by_label = {
            "_gender_": [
                "data/dataset/openai_gender_2/",
                "data/dataset/openai_gender_3/",
                # Trailing slash added: it was missing here only, unlike
                # every other directory in this file.
                "data/dataset/openai_gender_4/",
            ],
            "_education_": [
                "data/dataset/openai_education_three_classes_2/",
                "data/dataset/openai_education_three_classes_3/",
            ],
            "_socioeco_": [
                "data/dataset/openai_socioeconomic_2/",
            ],
            "_age_": [
                "data/dataset/openai_age_2/",
            ],
        }

        if label_idf == "_education_":
            sources = []
        else:
            sources = [
                # Replace _1/ with _2/ for the second dataset
                directory.replace("_1/", "_2/"),
                self.openai_dataset[label_idf],
            ]
        return sources + extras_by_label.get(label_idf, [])

    def _create_dataset(self, directory: str, label_idf: str,
                        label_to_id: Dict, control_probe: bool = False) -> "TextDataset":
        """Build the ``TextDataset`` for one attribute.

        Args:
            directory: Primary dataset directory
            label_idf: Attribute identifier such as "_age_"
            label_to_id: Mapping from class name to class index
            control_probe: Forwarded to ``TextDataset``; True when training
                controlling probes

        Returns:
            The constructed dataset.

        Raises:
            Exception: Whatever ``TextDataset`` raises, re-raised after the
                error has been printed.
        """
        additional_datasets = self._get_additional_datasets(label_idf, directory)
        print(f" π Creating dataset from {directory}")
        print(f" π Additional datasets: {len(additional_datasets)} sources")
        try:
            dataset = TextDataset(
                directory,
                self.tokenizer,
                self.model,
                label_idf=label_idf,
                label_to_id=label_to_id,
                convert_to_llama2_format=True,
                additional_datas=additional_datasets,
                new_format=self.new_prompt_format,
                control_probe=control_probe,
                residual_stream=self.residual_stream,
                if_augmented=self.augmented,
                remove_last_ai_response=self.remove_last_ai_response,
                include_inst=self.include_inst,
                k=1,
                # NOTE(review): the dataset is built with one_hot=False while
                # self.one_hot (True) is only forwarded to train()/test() -
                # confirm this asymmetry is intended.
                one_hot=False,
                last_tok_pos=-1,
            )
        except Exception as e:
            print(f" β ERROR: Failed to create dataset: {e}")
            raise
        print(f" β Dataset created with {len(dataset)} samples")
        return dataset

    def _create_data_loaders(self, dataset: "TextDataset") -> Tuple[DataLoader, DataLoader]:
        """Split ``dataset`` 80/20 (stratified on its labels) into two loaders.

        The split uses a fixed random seed. The train loader shuffles with
        batch size 200; the test loader keeps order with batch size 400.

        Raises:
            Exception: Any error from the split or loader construction,
                re-raised after the error has been printed.
        """
        train_size = int(0.8 * len(dataset))
        test_size = len(dataset) - train_size
        print(f" π Splitting dataset: {train_size} train, {test_size} test")
        try:
            train_idx, val_idx = sklearn.model_selection.train_test_split(
                list(range(len(dataset))),
                test_size=test_size,
                train_size=train_size,
                random_state=12345,
                shuffle=True,
                stratify=dataset.labels,
            )
            train_loader = DataLoader(
                Subset(dataset, train_idx),
                shuffle=True,
                pin_memory=True,
                batch_size=200,
                num_workers=1,
            )
            test_loader = DataLoader(
                Subset(dataset, val_idx),
                shuffle=False,
                pin_memory=True,
                batch_size=400,
                num_workers=1,
            )
        except Exception as e:
            print(f" β ERROR: Failed to create data loaders: {e}")
            raise
        print(f" β Data loaders created")
        return train_loader, test_loader

    def _probe_input_dim(self) -> int:
        """Return the probe input width for the loaded model.

        Uses ``self.model.config.hidden_size`` when it is available and falls
        back to ``DEFAULT_INPUT_DIM`` otherwise, so models other than the
        default one do not hit a fixed 5120.
        """
        config = getattr(getattr(self, "model", None), "config", None)
        return getattr(config, "hidden_size", None) or self.DEFAULT_INPUT_DIM

    def _save_confusion_matrix(self, test_results, dict_name: str,
                               layer_num: int, save_dir: str) -> None:
        """Write the confusion-matrix PNG for one layer (best effort).

        Failures only produce a warning line; they never abort training.
        """
        try:
            # presumably test_results[3] holds the true labels and
            # test_results[2] the predictions (confusion_matrix takes y_true
            # first) - verify against src.train_test_utils.test
            cm = confusion_matrix(test_results[3], test_results[2])
            display = ConfusionMatrixDisplay(
                cm,
                display_labels=list(self.label_mappings[f"_{dict_name}_"]),
            )
            try:
                display.plot()
                plt.savefig(f"{save_dir}/{dict_name}_layer_{layer_num}_confusion.png")
            finally:
                # Close the figure even when saving fails so figures do not
                # pile up across layers.
                plt.close()
        except Exception as e:
            print(f" β οΈ Warning: Could not generate confusion matrix: {e}")

    def _train_probe_for_layer(self, train_loader: DataLoader, test_loader: DataLoader,
                               layer_num: int, num_classes: int, dict_name: str,
                               save_dir: str, max_epochs: int = 50) -> Tuple[float, float, float]:
        """Train a probe for a specific layer.

        A checkpoint is written whenever the test accuracy (index 1 of the
        ``test`` result) improves. After the last epoch a ``_final``
        checkpoint and a confusion-matrix image are written.

        Args:
            train_loader: Loader over the training split
            test_loader: Loader over the held-out split
            layer_num: Layer index forwarded to ``train``/``test``
            num_classes: Number of classes the probe predicts
            dict_name: Attribute name without underscores; used in file names
            save_dir: Directory receiving checkpoints and images
            max_epochs: Number of epochs to run

        Returns:
            ``(best_test_acc, final_test_acc, final_train_acc)``; the last two
            stay 0 when no epoch is run.
        """
        trainer_config = TrainerConfig()
        probe = LinearProbeClassification(
            probe_class=num_classes,
            device=self.device,
            input_dim=self._probe_input_dim(),
            logistic=self.logistic,
        )
        optimizer, scheduler = probe.configure_optimizers(trainer_config)
        loss_func = edl_mse_loss if self.uncertainty else nn.BCELoss()

        best_acc = 0
        final_test_acc = 0
        final_train_acc = 0
        test_results = None  # stays None when max_epochs < 1

        for epoch in range(1, max_epochs + 1):
            is_last_epoch = epoch == max_epochs  # helpers are verbose only then
            # The two modes differ in exactly one keyword argument.
            if self.uncertainty:
                mode_kwargs = {"epoch_num": epoch}
            else:
                mode_kwargs = {"one_hot": self.one_hot}

            train_results = train(
                probe, self.device, train_loader, optimizer,
                epoch, loss_func=loss_func, verbose_interval=None,
                verbose=is_last_epoch, layer_num=layer_num,
                return_raw_outputs=True, num_classes=num_classes,
                **mode_kwargs,
            )
            test_results = test(
                probe, self.device, test_loader, loss_func=loss_func,
                return_raw_outputs=True, verbose=is_last_epoch,
                layer_num=layer_num, scheduler=scheduler,
                num_classes=num_classes, **mode_kwargs,
            )

            if test_results[1] > best_acc:
                best_acc = test_results[1]
                save_path = f"{save_dir}/{dict_name}_probe_at_layer_{layer_num}.pth"
                torch.save(probe.state_dict(), save_path)

            if is_last_epoch:
                final_test_acc = test_results[1]
                final_train_acc = train_results[1]

        # Save final model
        final_path = f"{save_dir}/{dict_name}_probe_at_layer_{layer_num}_final.pth"
        torch.save(probe.state_dict(), final_path)

        # Generate confusion matrix (needs at least one evaluated epoch)
        if test_results is not None:
            self._save_confusion_matrix(test_results, dict_name, layer_num, save_dir)

        return best_acc, final_test_acc, final_train_acc

    def train_probes(self, probe_type: str = "reading", num_layers: int = 41):
        """
        Train probes for all attributes and layers.

        A failure in one layer records 0 for that layer's three accuracies. A
        failure for a whole attribute is printed and that attribute is
        skipped. After every attribute that succeeds, the accumulated
        accuracies are pickled to ``{save_dir}_experiment.pkl``, which sits
        next to the checkpoint directory rather than inside it.

        Args:
            probe_type: Type of probe to train ("reading" or "controlling")
            num_layers: Number of layers to train probes for

        Returns:
            Dict with, per attribute name, the keys ``name``, ``name_final``
            and ``name_train``, each mapping to a list of per-layer
            accuracies.
        """
        print(f"\n{'='*80}")
        print(f"π― Training {probe_type.upper()} PROBES")
        print(f"{'='*80}\n")

        # Create output directory
        save_dir = f"probe_checkpoints/{probe_type}_probe"
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        print(f"π Output directory: {save_dir}")

        accuracy_dict = {}
        control_probe = (probe_type == "controlling")

        for directory, label_idf in self.dataset_configs:
            dict_name = label_idf.strip("_")
            label_to_id = self.label_mappings[label_idf]
            print(f"\n{'-'*60}")
            print(f"π·οΈ Processing: {self.prompt_translator[label_idf].upper()}")
            print(f" Classes: {list(label_to_id.keys())}")
            print(f"{'-'*60}")

            dataset = train_loader = test_loader = None
            try:
                dataset = self._create_dataset(
                    directory, label_idf, label_to_id, control_probe
                )
                train_loader, test_loader = self._create_data_loaders(dataset)

                accs = []
                final_accs = []
                train_accs = []

                # Train probes for each layer
                print(f"\n π Training probes for {num_layers} layers...")
                for layer_num in tqdm(range(num_layers), desc=f" Layers for {dict_name}"):
                    try:
                        print(f"\n Layer {layer_num}:")
                        best_acc, final_test_acc, final_train_acc = self._train_probe_for_layer(
                            train_loader, test_loader, layer_num,
                            len(label_to_id), dict_name, save_dir
                        )
                    except Exception as e:
                        print(f" β ERROR: Failed to train layer {layer_num}: {e}")
                        # Keep list positions aligned with layer indices.
                        best_acc = final_test_acc = final_train_acc = 0
                    else:
                        print(f" π Best: {best_acc:.3f}, Final: {final_test_acc:.3f}, Train: {final_train_acc:.3f}")
                    accs.append(best_acc)
                    final_accs.append(final_test_acc)
                    train_accs.append(final_train_acc)

                # Save accuracies
                accuracy_dict[dict_name] = accs
                accuracy_dict[dict_name + "_final"] = final_accs
                accuracy_dict[dict_name + "_train"] = train_accs

                # Save intermediate results
                with open(f"{save_dir}_experiment.pkl", "wb") as outfile:
                    pickle.dump(accuracy_dict, outfile)
                print(f" πΎ Saved results to {save_dir}_experiment.pkl")
            except Exception as e:
                print(f" β ERROR: Failed to process {dict_name}: {e}")
            finally:
                # Clean up memory on failure as well, so one broken attribute
                # does not keep its dataset alive while the next one is built.
                dataset = train_loader = test_loader = None
                torch.cuda.empty_cache()
                print(f" π§Ή Cleaned up memory")

        print(f"\n{'='*80}")
        print(f"β COMPLETED {probe_type.upper()} PROBE TRAINING")
        print(f"{'='*80}\n")

        # Print summary
        self._print_summary(accuracy_dict, probe_type)
        return accuracy_dict

    def _print_summary(self, accuracy_dict: Dict, probe_type: str):
        """Print best layer and mean of the best-accuracy list per attribute.

        Keys ending in ``_final`` or ``_train`` and attributes with an empty
        list are skipped. On ties the lowest layer index is reported.
        """
        print(f"\nπ SUMMARY for {probe_type} probes:")
        print("-" * 40)
        for attribute, best_accs in accuracy_dict.items():
            if attribute.endswith(("_final", "_train")) or not best_accs:
                continue
            max_acc = max(best_accs)
            best_layer = best_accs.index(max_acc)
            avg_acc = sum(best_accs) / len(best_accs)
            print(f" {attribute:12s}: Best={max_acc:.3f} (layer {best_layer}), Avg={avg_acc:.3f}")
def main():
    """Parse the command line, build a ``ProbeTrainer`` and run the training.

    Prints a banner with the chosen settings, trains the requested probe
    type(s) and reports the elapsed wall-clock time. Exits with status 1 when
    interrupted by the user or when any other exception escapes the training
    code (after printing its traceback).
    """
    parser = argparse.ArgumentParser(
        description="Train reading and controlling probes for LLM attribute detection"
    )
    parser.add_argument(
        "--probe-type",
        choices=["reading", "controlling", "both"],
        default="both",
        help="Type of probes to train",
    )
    parser.add_argument(
        "--model",
        default="meta-llama/Llama-2-13b-chat-hf",
        help="HuggingFace model to use",
    )
    parser.add_argument(
        "--device",
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to use for training",
    )
    parser.add_argument(
        "--num-layers",
        type=int,
        default=41,
        help="Number of layers to train probes for",
    )
    parser.add_argument(
        "--no-auth",
        action="store_true",
        help="Don't use authentication token",
    )
    args = parser.parse_args()

    print(f"""
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β LLM Probe Training System β
β β
β Model: {args.model:50s} β
β Device: {args.device:49s} β
β Probe Type: {args.probe_type:45s} β
β Layers: {args.num_layers:49d} β
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
""")

    started_at = time.time()
    try:
        # Initialize trainer
        trainer = ProbeTrainer(
            model_name=args.model,
            device=args.device,
            use_auth_token=not args.no_auth,
        )

        # "both" runs reading first, then controlling; otherwise only the
        # requested type is trained.
        if args.probe_type == "both":
            print("\nπ Training both reading and controlling probes...")
            requested = ("reading", "controlling")
        elif args.probe_type == "reading":
            requested = ("reading",)
        else:
            requested = ("controlling",)
        for probe_type in requested:
            trainer.train_probes(probe_type, args.num_layers)

        elapsed = time.time() - started_at
        print(f"\nβ±οΈ Total training time: {elapsed/60:.2f} minutes")
        print("β Training completed successfully!")
    except KeyboardInterrupt:
        print("\n\nβ οΈ Training interrupted by user")
        sys.exit(1)
    except Exception as exc:
        print(f"\nβ FATAL ERROR: {exc}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
| if __name__ == "__main__": | |
| main() |