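"""Entry point for the collaborative decoding app: wires together the model,
scoring, session, and UI components and launches the Gradio interface."""
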
# Import our modules
from src.models.llm_manager import LLMManager
from src.models.similarity_calculator import SimilarityCalculator
from src.models.data_manager import DataManager
from src.scoring.scorer import Scorer
from src.scoring.statistics import StatisticsCalculator
from src.scoring.achievements import AchievementSystem
from src.session.session_manager import SessionManager
from src.ui.template_renderer import TemplateRenderer
from src.ui.page_handlers import PageHandlers
from src.ui.interface_builder import InterfaceBuilder
from src.config.settings import DEFAULT_SERVER_NAME, DEFAULT_SERVER_PORT, DEFAULT_SHARE

class CollaborativeDecodingApp:
    def __init__(self):
        self.current_prompt = None
        
        # Initialize managers
        self.llm_manager = LLMManager()
        self.similarity_calculator = SimilarityCalculator()
        self.data_manager = DataManager()
        self.scorer = Scorer()
        self.statistics_calculator = StatisticsCalculator()
        self.achievement_system = AchievementSystem()
        self.session_manager = SessionManager(
            data_manager=self.data_manager,
            achievement_system=self.achievement_system
        )
        
        # Initialize UI components
        self.template_renderer = TemplateRenderer()
        self.page_handlers = PageHandlers(self)
        self.interface_builder = InterfaceBuilder(self, self.page_handlers)
        
        # Use session manager's session ID
        self.session_id = self.session_manager.session_id

    def load_data(self):
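        """Load the prompts dataset via the data manager"""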
        self.data_manager.load_prompts_data()

    def load_models(self):
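        """Load the language model and the similarity model"""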
        self.llm_manager.load_models()
        self.similarity_calculator.load_model()

    def get_random_prompt(self):
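        """Pick a random prompt and cache it as the current prompt"""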
        self.current_prompt = self.data_manager.get_random_prompt()
        return self.current_prompt

    def validate_user_input(self, user_input):
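        """Check that the user's continuation stays within the token limit"""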
        return self.llm_manager.validate_user_input(user_input)

    def process_submission(self, user_input):
        """Process user submission and return results"""
        # Validation
        if not user_input.strip():
            return "Please enter some text to continue the response.", None, None, None, None, None, None
        
        if not self.validate_user_input(user_input):
            return "Please keep your input to 5 tokens or less.", None, None, None, None, None, None
        
        if not self.current_prompt:
            return "Error: No prompt loaded. Please refresh the page.", None, None, None, None, None, None
        
        # Generate response
        try:
            generated_response = self.llm_manager.generate_response_from_user_input(
                self.current_prompt["prompt"],
                self.current_prompt["llm_partial_response"], 
                user_input
            )
        except Exception as e:
            return f"Error generating response: {str(e)}", None, None, None, None, None, None
        
        # Calculate similarity score
        original_full = self.current_prompt["llm_full_response_original"]
        cosine_distance = self.similarity_calculator.compute_cosine_distance(original_full, generated_response)
        
        # Save interaction with token count
        num_user_tokens = self.llm_manager.count_tokens(user_input)
        self.data_manager.save_interaction(
            self.current_prompt, user_input, generated_response, 
            cosine_distance, self.session_id, num_user_tokens
        )
        
        # Calculate additional metrics for results display
        all_results = self.data_manager.load_results_data()
        prompt_results = self.data_manager.filter_results_by_partial_response(
            all_results, self.current_prompt["prompt"], self.current_prompt["llm_partial_response"]
        )
        
        # Calculate rank and percentile
        rank, percentile = self.scorer.calculate_rank_and_percentile(cosine_distance, prompt_results, num_user_tokens)
        
        # Calculate mean score (for legacy compatibility)
        scores = [r["cosine_distance"] for r in prompt_results if r["num_user_tokens"] == num_user_tokens]
        mean_score = sum(scores) / len(scores) if scores else cosine_distance
        
        # Create violin plot if there are enough results
        violin_plot = None
        if len(prompt_results) >= 3:
            violin_plot = self.statistics_calculator.create_violin_plot(prompt_results, cosine_distance, num_user_tokens)
        
        return generated_response, cosine_distance, rank, percentile, mean_score, violin_plot, prompt_results

    def create_interface(self):
        """Create the main Gradio interface"""
        return self.interface_builder.create_interface()


def main():
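    """Load data and models, then build and launch the interface"""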
    app = CollaborativeDecodingApp()
    app.load_data()
    app.load_models()
    
    demo = app.create_interface()
    demo.launch(share=DEFAULT_SHARE, server_name=DEFAULT_SERVER_NAME, server_port=DEFAULT_SERVER_PORT)

if __name__ == "__main__":
    main()