Alon Albalak
initial commit
71a764a
raw
history blame
5.01 kB
# Import our modules
from src.models.llm_manager import LLMManager
from src.models.similarity_calculator import SimilarityCalculator
from src.models.data_manager import DataManager
from src.scoring.scorer import Scorer
from src.scoring.statistics import StatisticsCalculator
from src.scoring.achievements import AchievementSystem
from src.session.session_manager import SessionManager
from src.ui.template_renderer import TemplateRenderer
from src.ui.page_handlers import PageHandlers
from src.ui.interface_builder import InterfaceBuilder
from src.config.settings import DEFAULT_SERVER_NAME, DEFAULT_SERVER_PORT, DEFAULT_SHARE
class CollaborativeDecodingApp:
    """Top-level wiring for the collaborative decoding game.

    Composes the model, data, scoring, session, and UI components and routes
    a user's text submission through generation, similarity scoring, and
    persistence.
    """

    def __init__(self):
        # Prompt currently shown to the user; set by get_random_prompt().
        self.current_prompt = None

        # Core model/data managers.
        self.llm_manager = LLMManager()
        self.similarity_calculator = SimilarityCalculator()
        self.data_manager = DataManager()

        # Scoring, statistics, and achievements.
        self.scorer = Scorer()
        self.statistics_calculator = StatisticsCalculator()
        self.achievement_system = AchievementSystem()
        self.session_manager = SessionManager(
            data_manager=self.data_manager,
            achievement_system=self.achievement_system
        )

        # UI components; handlers and builder hold a back-reference to the app.
        self.template_renderer = TemplateRenderer()
        self.page_handlers = PageHandlers(self)
        self.interface_builder = InterfaceBuilder(self, self.page_handlers)

        # Reuse the session manager's session ID when saving interactions.
        self.session_id = self.session_manager.session_id

    @staticmethod
    def _error_result(message):
        """Return the 7-tuple shape process_submission yields on failure.

        Only the first element (the user-facing message) is populated; the
        remaining six slots mirror the success-tuple layout.
        """
        return (message, None, None, None, None, None, None)

    def load_data(self):
        """Load the prompt dataset via the data manager."""
        self.data_manager.load_prompts_data()

    def load_models(self):
        """Load the generation model(s) and the similarity model."""
        self.llm_manager.load_models()
        self.similarity_calculator.load_model()

    def get_random_prompt(self):
        """Pick a new random prompt, remember it as current, and return it."""
        self.current_prompt = self.data_manager.get_random_prompt()
        return self.current_prompt

    def validate_user_input(self, user_input):
        """Return whether the input passes the LLM manager's token limit."""
        return self.llm_manager.validate_user_input(user_input)

    def process_submission(self, user_input):
        """Process user submission and return results.

        Returns a 7-tuple:
        (generated_response_or_error_message, cosine_distance, rank,
        percentile, mean_score, violin_plot, prompt_results).
        On any failure only the first element carries a message; the rest
        are None.
        """
        # Validation — fail fast with a user-facing message.
        if not user_input.strip():
            return self._error_result("Please enter some text to continue the response.")
        if not self.validate_user_input(user_input):
            return self._error_result("Please keep your input to 5 tokens or less.")
        if not self.current_prompt:
            return self._error_result("Error: No prompt loaded. Please refresh the page.")

        # Generate the continuation seeded with the user's text.
        try:
            generated_response = self.llm_manager.generate_response_from_user_input(
                self.current_prompt["prompt"],
                self.current_prompt["llm_partial_response"],
                user_input
            )
        except Exception as e:
            return self._error_result(f"Error generating response: {str(e)}")

        # Similarity of the collaborative response to the original full one.
        original_full = self.current_prompt["llm_full_response_original"]
        cosine_distance = self.similarity_calculator.compute_cosine_distance(
            original_full, generated_response
        )

        # Persist the interaction along with the user's token count.
        num_user_tokens = self.llm_manager.count_tokens(user_input)
        self.data_manager.save_interaction(
            self.current_prompt, user_input, generated_response,
            cosine_distance, self.session_id, num_user_tokens
        )

        # Compare against prior submissions for this same prompt/partial pair.
        all_results = self.data_manager.load_results_data()
        prompt_results = self.data_manager.filter_results_by_partial_response(
            all_results, self.current_prompt["prompt"],
            self.current_prompt["llm_partial_response"]
        )

        rank, percentile = self.scorer.calculate_rank_and_percentile(
            cosine_distance, prompt_results, num_user_tokens
        )

        # Mean score over results with the same token count (legacy metric).
        # .get() skips legacy records missing "num_user_tokens" instead of
        # raising KeyError and failing the whole submission.
        scores = [
            r["cosine_distance"]
            for r in prompt_results
            if r.get("num_user_tokens") == num_user_tokens
        ]
        mean_score = sum(scores) / len(scores) if scores else cosine_distance

        # Violin plot only when there is enough data to be meaningful.
        violin_plot = None
        if len(prompt_results) >= 3:
            violin_plot = self.statistics_calculator.create_violin_plot(
                prompt_results, cosine_distance, num_user_tokens
            )

        return generated_response, cosine_distance, rank, percentile, mean_score, violin_plot, prompt_results

    def create_interface(self):
        """Create the main Gradio interface"""
        return self.interface_builder.create_interface()
def main():
    """Build the application, load its data and models, and launch the UI."""
    application = CollaborativeDecodingApp()
    application.load_data()
    application.load_models()
    application.create_interface().launch(
        share=DEFAULT_SHARE,
        server_name=DEFAULT_SERVER_NAME,
        server_port=DEFAULT_SERVER_PORT,
    )


if __name__ == "__main__":
    main()