Spaces:
Status: Sleeping
Author: Alon Albalak
Latest commit: "major update: data saved to hf, user sessions maintain separation, fixed scoring bug" (36d5e94)
| # Import our modules | |
| from src.models.llm_manager import LLMManager | |
| from src.models.similarity_calculator import SimilarityCalculator | |
| from src.models.data_manager import DataManager | |
| from src.scoring.scorer import Scorer | |
| from src.scoring.statistics import StatisticsCalculator | |
| from src.scoring.achievements import AchievementSystem | |
| from src.ui.template_renderer import TemplateRenderer | |
| from src.ui.page_handlers import PageHandlers | |
| from src.ui.interface_builder import InterfaceBuilder | |
class CollaborativeDecodingApp:
    """Top-level wiring for the collaborative decoding demo.

    Holds only shared managers with no per-user state (models, scoring,
    data persistence, UI builders); per-user context flows through method
    arguments such as ``session_id``.
    """

    # Number of slots in the process_submission result tuple:
    # (message_or_response, cosine_distance, rank, percentile,
    #  mean_score, violin_plot, prompt_results)
    _RESULT_SLOTS = 7

    @staticmethod
    def _error(message):
        """Return an error-state result tuple: the message plus six ``None``s.

        Keeps the failure shape of ``process_submission`` defined in exactly
        one place instead of four hand-written 7-tuples.
        """
        return message, None, None, None, None, None, None

    def __init__(self):
        # Initialize shared managers (no per-user state)
        self.llm_manager = LLMManager()
        self.similarity_calculator = SimilarityCalculator()
        self.data_manager = DataManager()
        self.scorer = Scorer()
        self.statistics_calculator = StatisticsCalculator()
        self.achievement_system = AchievementSystem()
        # Initialize UI components; handlers/builder get a back-reference to
        # this app so they can reach the shared managers above.
        self.template_renderer = TemplateRenderer()
        self.page_handlers = PageHandlers(self)
        self.interface_builder = InterfaceBuilder(self, self.page_handlers)

    def load_data(self):
        """Load the prompt dataset via the data manager."""
        self.data_manager.load_prompts_data()

    def load_models(self):
        """Load the generation LLM(s) and the sentence-similarity model."""
        self.llm_manager.load_models()
        self.similarity_calculator.load_model()

    def get_random_prompt(self):
        """Get a random prompt from the dataset"""
        return self.data_manager.get_random_prompt()

    def validate_user_input(self, user_input):
        """Validate user input using LLM manager"""
        return self.llm_manager.validate_user_input(user_input)

    def process_submission(self, user_input, current_prompt, session_id):
        """Process user submission and return results with updated state.

        Args:
            user_input: Text the user typed to continue the LLM response.
            current_prompt: Prompt record; expected to carry "prompt",
                "llm_partial_response" and "llm_full_response_original" keys.
            session_id: Opaque per-user session identifier, stored with the
                interaction for session separation.

        Returns:
            A 7-tuple ``(generated_response, cosine_distance, rank,
            percentile, mean_score, violin_plot, prompt_results)``; on any
            validation or generation failure the first slot is an error
            message and the remaining six are ``None``.
        """
        # Validation — cheap guards first, each short-circuiting.
        if not user_input.strip():
            return self._error("Please enter some text to continue the response.")
        if not self.validate_user_input(user_input):
            return self._error("Please keep your input to 5 tokens or less.")
        if not current_prompt:
            return self._error("Error: No prompt loaded. Please refresh the page.")

        # Generate response; generation is the only step expected to raise,
        # so the try body is kept to that single call.
        try:
            generated_response = self.llm_manager.generate_response_from_user_input(
                current_prompt["prompt"],
                current_prompt["llm_partial_response"],
                user_input,
            )
        except Exception as e:
            return self._error(f"Error generating response: {str(e)}")

        # Calculate similarity score against the original full response.
        original_full = current_prompt["llm_full_response_original"]
        cosine_distance = self.similarity_calculator.compute_cosine_distance(
            original_full, generated_response
        )

        # Save interaction with token count (persisted to HF by the data manager).
        num_user_tokens = self.llm_manager.count_tokens(user_input)
        self.data_manager.save_interaction_to_hf(
            current_prompt, user_input, generated_response,
            cosine_distance, session_id, num_user_tokens,
        )

        # Calculate additional metrics for results display, restricted to
        # submissions against this same prompt + partial response.
        all_results = self.data_manager.get_results()
        prompt_results = self.data_manager.filter_results_by_partial_response(
            all_results, current_prompt["prompt"], current_prompt["llm_partial_response"]
        )

        # Calculate rank and percentile among same-token-count submissions.
        rank, percentile = self.scorer.calculate_rank_and_percentile(
            cosine_distance, prompt_results, num_user_tokens
        )

        # Calculate mean score (for legacy compatibility). Use .get() so a
        # legacy record missing either key is skipped instead of raising
        # KeyError and aborting the whole submission.
        scores = [
            r["cosine_distance"]
            for r in prompt_results
            if r.get("num_user_tokens") == num_user_tokens and "cosine_distance" in r
        ]
        mean_score = sum(scores) / len(scores) if scores else cosine_distance

        # Create violin plot only when there are enough results to be meaningful.
        violin_plot = None
        if len(prompt_results) >= 3:
            violin_plot = self.statistics_calculator.create_violin_plot(
                prompt_results, cosine_distance, num_user_tokens
            )

        return generated_response, cosine_distance, rank, percentile, mean_score, violin_plot, prompt_results

    def create_interface(self):
        """Create the main Gradio interface"""
        return self.interface_builder.create_interface()
def main():
    """Assemble the app, load its data and models, then launch the UI."""
    application = CollaborativeDecodingApp()
    application.load_data()
    application.load_models()
    # Build the Gradio interface and start serving it.
    application.create_interface().launch()


if __name__ == "__main__":
    main()