Spaces:
Sleeping
Sleeping
File size: 5,011 Bytes
71a764a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
# Import our modules
from src.models.llm_manager import LLMManager
from src.models.similarity_calculator import SimilarityCalculator
from src.models.data_manager import DataManager
from src.scoring.scorer import Scorer
from src.scoring.statistics import StatisticsCalculator
from src.scoring.achievements import AchievementSystem
from src.session.session_manager import SessionManager
from src.ui.template_renderer import TemplateRenderer
from src.ui.page_handlers import PageHandlers
from src.ui.interface_builder import InterfaceBuilder
from src.config.settings import DEFAULT_SERVER_NAME, DEFAULT_SERVER_PORT, DEFAULT_SHARE
class CollaborativeDecodingApp:
    """Wire together the model, data, scoring, session and UI components.

    The instance keeps the prompt currently shown to the user in
    ``current_prompt`` and exposes ``process_submission`` as the single
    entry point that turns a user continuation into a scored result.
    """

    # Number of ``None`` placeholders that follow the message in an error
    # result, so error and success results have the same length (7).
    _ERROR_PADDING = 6

    def __init__(self) -> None:
        """Instantiate every manager and UI helper used by the app."""
        # Prompt served by the last ``get_random_prompt`` call, if any.
        self.current_prompt = None

        # Initialize managers
        self.llm_manager = LLMManager()
        self.similarity_calculator = SimilarityCalculator()
        self.data_manager = DataManager()
        self.scorer = Scorer()
        self.statistics_calculator = StatisticsCalculator()
        self.achievement_system = AchievementSystem()
        self.session_manager = SessionManager(
            data_manager=self.data_manager,
            achievement_system=self.achievement_system,
        )

        # Initialize UI components. The page handlers and interface builder
        # receive this app instance, so they are created after the managers.
        self.template_renderer = TemplateRenderer()
        self.page_handlers = PageHandlers(self)
        self.interface_builder = InterfaceBuilder(self, self.page_handlers)

        # Use session manager's session ID
        self.session_id = self.session_manager.session_id

    def load_data(self) -> None:
        """Ask the data manager to load the prompts data."""
        self.data_manager.load_prompts_data()

    def load_models(self) -> None:
        """Load the LLM(s) and the similarity model."""
        self.llm_manager.load_models()
        self.similarity_calculator.load_model()

    def get_random_prompt(self):
        """Pick a random prompt, remember it as the current one, and return it."""
        self.current_prompt = self.data_manager.get_random_prompt()
        return self.current_prompt

    def validate_user_input(self, user_input):
        """Return the LLM manager's verdict on whether ``user_input`` is acceptable."""
        return self.llm_manager.validate_user_input(user_input)

    @staticmethod
    def _error_result(message):
        """Build an error result: the message followed by ``None`` placeholders."""
        return (message,) + (None,) * CollaborativeDecodingApp._ERROR_PADDING

    def process_submission(self, user_input):
        """Process user submission and return results.

        Args:
            user_input: Text typed by the user to continue the partial
                response. ``None`` is treated like an empty string.

        Returns:
            A 7-tuple. On success it is ``(generated_response,
            cosine_distance, rank, percentile, mean_score, violin_plot,
            prompt_results)``, where ``violin_plot`` is ``None`` when fewer
            than three results exist for the prompt. On failure it is
            ``(error_message, None, None, None, None, None, None)``.
        """
        # Validation. The ``None`` guard keeps a missing value from raising
        # AttributeError on ``.strip()``.
        if user_input is None or not user_input.strip():
            return self._error_result("Please enter some text to continue the response.")
        if not self.validate_user_input(user_input):
            # The actual limit is enforced by the LLM manager; the message
            # presumably mirrors it.
            return self._error_result("Please keep your input to 5 tokens or less.")
        if not self.current_prompt:
            return self._error_result("Error: No prompt loaded. Please refresh the page.")

        prompt = self.current_prompt

        # Generate response. Any failure is reported to the user as a
        # message rather than propagated.
        try:
            generated_response = self.llm_manager.generate_response_from_user_input(
                prompt["prompt"],
                prompt["llm_partial_response"],
                user_input,
            )
        except Exception as exc:
            return self._error_result(f"Error generating response: {exc}")

        # Calculate similarity score against the original full response.
        cosine_distance = self.similarity_calculator.compute_cosine_distance(
            prompt["llm_full_response_original"], generated_response
        )

        # Save interaction with token count
        num_user_tokens = self.llm_manager.count_tokens(user_input)
        self.data_manager.save_interaction(
            prompt, user_input, generated_response,
            cosine_distance, self.session_id, num_user_tokens
        )

        # Reload the stored results (after saving) and keep those recorded
        # for the same prompt and partial response.
        all_results = self.data_manager.load_results_data()
        prompt_results = self.data_manager.filter_results_by_partial_response(
            all_results, prompt["prompt"], prompt["llm_partial_response"]
        )

        # Calculate rank and percentile
        rank, percentile = self.scorer.calculate_rank_and_percentile(
            cosine_distance, prompt_results, num_user_tokens
        )

        # Calculate mean score (for legacy compatibility) over results with
        # the same token count; fall back to this submission's own score.
        scores = [
            r["cosine_distance"]
            for r in prompt_results
            if r["num_user_tokens"] == num_user_tokens
        ]
        mean_score = sum(scores) / len(scores) if scores else cosine_distance

        # Create violin plot if there are enough results
        violin_plot = None
        if len(prompt_results) >= 3:
            violin_plot = self.statistics_calculator.create_violin_plot(
                prompt_results, cosine_distance, num_user_tokens
            )

        return (
            generated_response,
            cosine_distance,
            rank,
            percentile,
            mean_score,
            violin_plot,
            prompt_results,
        )

    def create_interface(self):
        """Create the main Gradio interface"""
        return self.interface_builder.create_interface()
def main():
    """Build the app, load its data and models, then launch the interface."""
    decoding_app = CollaborativeDecodingApp()
    decoding_app.load_data()
    decoding_app.load_models()

    interface = decoding_app.create_interface()
    # Launch settings come from the project's configuration defaults.
    launch_options = {
        "share": DEFAULT_SHARE,
        "server_name": DEFAULT_SERVER_NAME,
        "server_port": DEFAULT_SERVER_PORT,
    }
    interface.launch(**launch_options)
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()