# Author: Alon Albalak
# Major update: data saved to HF, user sessions maintain separation, fixed scoring bug
# Commit: 36d5e94
# Import our modules
from src.models.llm_manager import LLMManager
from src.models.similarity_calculator import SimilarityCalculator
from src.models.data_manager import DataManager
from src.scoring.scorer import Scorer
from src.scoring.statistics import StatisticsCalculator
from src.scoring.achievements import AchievementSystem
from src.ui.template_renderer import TemplateRenderer
from src.ui.page_handlers import PageHandlers
from src.ui.interface_builder import InterfaceBuilder
class CollaborativeDecodingApp:
    """Application facade wiring shared model/data services to the Gradio UI.

    Holds only shared, user-independent managers; per-user state is tracked
    via the ``session_id`` passed into :meth:`process_submission`.
    """

    def __init__(self):
        # Shared service objects — intentionally hold no per-user state.
        self.llm_manager = LLMManager()
        self.similarity_calculator = SimilarityCalculator()
        self.data_manager = DataManager()
        self.scorer = Scorer()
        self.statistics_calculator = StatisticsCalculator()
        self.achievement_system = AchievementSystem()
        # UI layer; handlers and builder call back into this app instance.
        self.template_renderer = TemplateRenderer()
        self.page_handlers = PageHandlers(self)
        self.interface_builder = InterfaceBuilder(self, self.page_handlers)

    def load_data(self):
        """Load the prompt dataset through the data manager."""
        self.data_manager.load_prompts_data()

    def load_models(self):
        """Load the generation model(s) and the embedding model used for scoring."""
        self.llm_manager.load_models()
        self.similarity_calculator.load_model()

    def get_random_prompt(self):
        """Return a random prompt record from the dataset."""
        return self.data_manager.get_random_prompt()

    def validate_user_input(self, user_input):
        """Delegate input validation (token-length check) to the LLM manager."""
        return self.llm_manager.validate_user_input(user_input)

    def process_submission(self, user_input, current_prompt, session_id):
        """Validate, generate, score, persist, and summarize one user submission.

        Returns a 7-tuple:
        (message_or_response, cosine_distance, rank, percentile, mean_score,
        violin_plot, prompt_results). On any validation or generation failure
        the first element is an error message and the rest are ``None``.
        """
        blanks = (None,) * 6

        # --- Input validation -------------------------------------------
        if not user_input.strip():
            return ("Please enter some text to continue the response.",) + blanks
        if not self.validate_user_input(user_input):
            return ("Please keep your input to 5 tokens or less.",) + blanks
        if not current_prompt:
            return ("Error: No prompt loaded. Please refresh the page.",) + blanks

        # --- Generation (only this call is guarded) ----------------------
        try:
            generated_response = self.llm_manager.generate_response_from_user_input(
                current_prompt["prompt"],
                current_prompt["llm_partial_response"],
                user_input,
            )
        except Exception as e:
            return (f"Error generating response: {str(e)}",) + blanks

        # --- Similarity score vs. the original full response --------------
        original_full = current_prompt["llm_full_response_original"]
        cosine_distance = self.similarity_calculator.compute_cosine_distance(
            original_full, generated_response
        )

        # --- Persist interaction (with the user's token count) ------------
        num_user_tokens = self.llm_manager.count_tokens(user_input)
        self.data_manager.save_interaction_to_hf(
            current_prompt,
            user_input,
            generated_response,
            cosine_distance,
            session_id,
            num_user_tokens,
        )

        # --- Comparative metrics against prior submissions ----------------
        all_results = self.data_manager.get_results()
        prompt_results = self.data_manager.filter_results_by_partial_response(
            all_results, current_prompt["prompt"], current_prompt["llm_partial_response"]
        )
        rank, percentile = self.scorer.calculate_rank_and_percentile(
            cosine_distance, prompt_results, num_user_tokens
        )

        # Legacy mean: average over peers that used the same token count;
        # fall back to this submission's own distance when there are none.
        peer_scores = [
            entry["cosine_distance"]
            for entry in prompt_results
            if entry["num_user_tokens"] == num_user_tokens
        ]
        mean_score = sum(peer_scores) / len(peer_scores) if peer_scores else cosine_distance

        # Only draw a distribution plot once there is enough data.
        violin_plot = (
            self.statistics_calculator.create_violin_plot(
                prompt_results, cosine_distance, num_user_tokens
            )
            if len(prompt_results) >= 3
            else None
        )

        return (
            generated_response,
            cosine_distance,
            rank,
            percentile,
            mean_score,
            violin_plot,
            prompt_results,
        )

    def create_interface(self):
        """Build and return the main Gradio interface."""
        return self.interface_builder.create_interface()
def main():
    """Construct the app, load its data and models, then serve the Gradio UI."""
    application = CollaborativeDecodingApp()
    application.load_data()
    application.load_models()
    application.create_interface().launch()
# Run the app only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()