File size: 8,489 Bytes
6fc3143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""
Code Fixer - Auto-fixes common issues in generated code
"""

import ast
import io
import logging
import re
import tokenize
from typing import Iterable, List, Optional, Tuple

from .code_postprocessor import post_process_code, fix_undefined_colors
from .code_validator import CodeValidator


class CodeFixer:
    """Auto-fix common issues in generated Manim scene code.

    ``auto_fix`` applies deterministic, rule-based repairs;
    ``fix_and_validate`` alternates those repairs with validation; and
    ``fix_runtime_error`` asks an LLM to repair code from a runtime error.
    """

    # Import lines this fixer can restore. Each is added only when the
    # validator reports "Missing required import: <line>" for it.
    _KNOWN_IMPORTS = (
        'from manim import *',
        'from manimator.scene.voiceover_scene import VoiceoverScene',
        'from manimator.services.voiceover import SimpleElevenLabsService',
        'from pathlib import Path',
    )

    # Statement inserted at the top of construct() when no speech service is
    # configured, and the imports that statement needs to resolve its names.
    _VOICEOVER_SETUP = (
        'self.set_speech_service(SimpleElevenLabsService(voice_id="Rachel", '
        'cache_dir=Path("media/voiceover/elevenlabs")))'
    )
    _VOICEOVER_IMPORTS = (
        'from manimator.services.voiceover import SimpleElevenLabsService',
        'from pathlib import Path',
    )

    # A `def construct(self):` header line; an optional return annotation and
    # trailing comment are allowed.
    _CONSTRUCT_RE = re.compile(r'^\s*def construct\(self\)\s*(?:->[^:#]*)?:\s*(?:#.*)?$')

    # Keywords that, when they start a logical line, open a compound statement
    # whose header must contain a top-level ':'.
    _HEADER_KEYWORDS = frozenset({
        'if', 'elif', 'else', 'for', 'while', 'def', 'class',
        'with', 'try', 'except', 'finally',
    })
    # Keywords that form a compound-statement header after `async`.
    _ASYNC_HEADER_KEYWORDS = frozenset({'def', 'for', 'with'})

    # Token types that carry no statement content.
    _NON_CODE_TOKENS = frozenset({
        tokenize.NL, tokenize.COMMENT, tokenize.INDENT, tokenize.DEDENT,
    })

    # First fenced code block in an LLM reply (```python, ```py or a bare ```),
    # case-insensitive and tolerant of trailing spaces / CRLF line endings.
    _FENCE_RE = re.compile(
        r'```[ \t]*(?:python3?|py)?[ \t]*\r?\n(.*?)```',
        re.DOTALL | re.IGNORECASE,
    )

    def __init__(self):
        self.validator = CodeValidator()

    def auto_fix(self, code: str, errors: List[str]) -> str:
        """
        Attempt to auto-fix code based on errors.

        Args:
            code: Code to fix
            errors: List of error messages

        Returns:
            Fixed code
        """
        # Apply existing post-processor fixes
        fixed_code = post_process_code(code)

        # Fix missing imports
        fixed_code = self._fix_missing_imports(fixed_code, errors)

        # Fix undefined colors (already in post_process_code, but ensure it's applied)
        fixed_code = fix_undefined_colors(fixed_code)

        # Fix voiceover setup
        fixed_code = self._fix_voiceover_setup(fixed_code, errors)

        # Fix common syntax issues
        fixed_code = self._fix_syntax_issues(fixed_code)

        return fixed_code

    def _fix_missing_imports(self, code: str, errors: List[str]) -> str:
        """Add the known imports that ``errors`` reports as missing.

        Args:
            code: Source to patch.
            errors: Validator messages. An import from ``_KNOWN_IMPORTS`` is
                added when ``"Missing required import: <import line>"``
                occurs in them.

        Returns:
            ``code`` with each reported import that is not already present
            inserted after the last top-level import.
        """
        error_text = str(errors)
        reported = [
            line for line in self._KNOWN_IMPORTS
            if f'Missing required import: {line}' in error_text
        ]
        return self._ensure_imports(code, reported)

    def _ensure_imports(self, code: str, imports: Iterable[str]) -> str:
        """Insert each line of ``imports`` whose text does not already occur in ``code``."""
        missing = [line for line in dict.fromkeys(imports) if line not in code]
        if not missing:
            return code
        lines = code.split('\n')
        insert_at = self._import_insert_index(code, lines)
        lines[insert_at:insert_at] = missing
        return '\n'.join(lines)

    @staticmethod
    def _import_insert_index(code: str, lines: List[str]) -> int:
        """Return the 0-based line index just after the last top-level import.

        When ``code`` parses, the AST is used, so multi-line imports and
        imports nested inside functions are handled correctly. With no
        imports, a module docstring (if any) is kept as the first statement.
        When ``code`` does not parse, only unindented ``import``/``from``
        lines are considered.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            index = 0
            for i, line in enumerate(lines):
                if line.startswith(('import ', 'from ')):
                    index = i + 1
            return index

        index = 0
        for node in tree.body:
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                # end_lineno is 1-based, so it is also the 0-based index of the
                # line after the (possibly multi-line) import statement.
                index = node.end_lineno
        if index == 0 and ast.get_docstring(tree, clean=False) is not None:
            index = tree.body[0].end_lineno
        return index

    @staticmethod
    def _indent_width(line: str) -> int:
        """Return the number of leading whitespace characters in ``line``."""
        return len(line) - len(line.lstrip())

    @staticmethod
    def _is_blank_or_comment(line: str) -> bool:
        """Return True for whitespace-only lines and ``#`` comment lines."""
        stripped = line.strip()
        return not stripped or stripped.startswith('#')

    def _fix_voiceover_setup(self, code: str, errors: List[str]) -> str:
        """Insert a speech-service setup call at the top of ``construct()``.

        Runs only when ``errors`` mentions "Voiceover service not initialized".
        Does nothing if no ``def construct(self):`` line exists or the
        construct body already calls ``set_speech_service``. The imports the
        inserted statement relies on are added if missing, so the fix cannot
        introduce a NameError.
        """
        if 'Voiceover service not initialized' not in str(errors):
            return code

        lines = code.split('\n')
        def_idx = next(
            (i for i, line in enumerate(lines) if self._CONSTRUCT_RE.match(line)),
            None,
        )
        if def_idx is None:
            return code

        # The construct body is every following line that is blank, a comment,
        # or indented deeper than the `def` itself.
        def_indent = self._indent_width(lines[def_idx])
        body_end = def_idx + 1
        while body_end < len(lines) and (
            self._is_blank_or_comment(lines[body_end])
            or self._indent_width(lines[body_end]) > def_indent
        ):
            body_end += 1
        if any('set_speech_service' in line for line in lines[def_idx + 1:body_end]):
            return code

        # Insert before the first statement, skipping blank and comment lines,
        # and match that statement's indentation (8 spaces if there is none).
        insert_idx = def_idx + 1
        while insert_idx < len(lines) and self._is_blank_or_comment(lines[insert_idx]):
            insert_idx += 1
        indent = self._indent_width(lines[insert_idx]) if insert_idx < len(lines) else 8
        lines.insert(insert_idx, ' ' * indent + self._VOICEOVER_SETUP)
        return self._ensure_imports('\n'.join(lines), self._VOICEOVER_IMPORTS)

    def _fix_syntax_issues(self, code: str) -> str:
        """Repair compound-statement headers with a missing or doubled colon.

        Code that already parses is returned unchanged. Otherwise the source is
        tokenized, so strings, comments and multi-line bracketed expressions
        are respected. Every logical line that starts with a block keyword
        (``if``, ``for``, ``def``, ``class``, ...) but has no top-level ``:``
        gets one appended after its last token. A trailing ``::`` on such a
        line is collapsed to ``:``. Slices, comprehensions and conditional
        expressions are never touched. Source that cannot be tokenized is
        returned unchanged.
        """
        try:
            ast.parse(code)
            return code
        except (SyntaxError, ValueError):
            pass

        try:
            tokens = list(tokenize.generate_tokens(io.StringIO(code).readline))
        except (tokenize.TokenError, SyntaxError, ValueError):
            return code

        edits = []
        statement = []
        for tok in tokens:
            if tok.type in (tokenize.NEWLINE, tokenize.ENDMARKER):
                edit = self._header_colon_edit(statement)
                if edit is not None:
                    edits.append(edit)
                statement = []
            elif tok.type not in self._NON_CODE_TOKENS:
                statement.append(tok)

        lines = code.split('\n')
        # Apply right-to-left so an edit never shifts a later edit's position.
        for row, start, end, text in sorted(edits, reverse=True):
            line = lines[row - 1]
            lines[row - 1] = line[:start] + text + line[end:]
        return '\n'.join(lines)

    @classmethod
    def _header_colon_edit(
        cls, statement: List[tokenize.TokenInfo]
    ) -> Optional[Tuple[int, int, int, str]]:
        """Return a ``(row, start_col, end_col, replacement)`` colon fix for one
        logical line, or ``None`` if it is not a broken block header."""
        if not statement or statement[0].type != tokenize.NAME:
            return None
        keyword = statement[0].string
        if keyword == 'async':
            if len(statement) < 2 or statement[1].string not in cls._ASYNC_HEADER_KEYWORDS:
                return None
        elif keyword not in cls._HEADER_KEYWORDS:
            return None

        # Colons inside brackets belong to annotations, slices, dicts or
        # lambdas; only a depth-0 colon terminates the header.
        depth = 0
        top_level_colons = 0
        for tok in statement:
            if tok.type != tokenize.OP:
                continue
            if tok.string in ('(', '[', '{'):
                depth += 1
            elif tok.string in (')', ']', '}'):
                depth -= 1
            elif tok.string == ':' and depth == 0:
                top_level_colons += 1

        last = statement[-1]
        before_last = statement[-2] if len(statement) > 1 else None
        if (
            before_last is not None
            and last.type == tokenize.OP and last.string == ':'
            and before_last.type == tokenize.OP and before_last.string == ':'
        ):
            # Header ends in '::' -> drop the trailing colon.
            return (last.start[0], last.start[1], last.end[1], '')
        if top_level_colons:
            return None
        return (last.end[0], last.end[1], last.end[1], ':')

    def fix_and_validate(self, code: str, max_attempts: int = 3) -> Tuple[str, bool, List[str]]:
        """
        Fix code and validate until valid or max attempts reached.

        Args:
            code: Code to fix
            max_attempts: Maximum fix attempts

        Returns:
            Tuple of (fixed_code, is_valid, remaining_errors)
        """
        current_code = code

        for _ in range(max_attempts):
            is_valid, errors = self.validator.validate(current_code)

            if is_valid:
                return (current_code, True, [])

            # Get fixable errors
            fixable_errors = self.validator.get_fixable_errors(errors)

            if not fixable_errors:
                # No more fixable errors
                return (current_code, False, errors)

            fixed_code = self.auto_fix(current_code, fixable_errors)
            if fixed_code == current_code:
                # No progress: re-running the same fixes on the same input would
                # (assuming validate/auto_fix are pure) change nothing, so stop
                # instead of re-validating identical code.
                return (current_code, False, errors)
            current_code = fixed_code

        # Final validation
        is_valid, final_errors = self.validator.validate(current_code)
        return (current_code, is_valid, final_errors)

    def fix_runtime_error(self, code: str, error_message: str) -> str:
        """
        Ask an LLM to fix code based on a runtime error message.

        The model name is read from the ``CODE_GEN_MODEL`` environment
        variable (default ``"gpt-4o"``). If the reply contains a fenced code
        block, that block is used. The result is passed through
        ``post_process_code``.

        Returns:
            The repaired code. Returns ``code`` unchanged if the LLM call or
            post-processing raises, or if the reply contains no code.
        """
        import litellm
        import os

        model = os.getenv("CODE_GEN_MODEL", "gpt-4o")

        messages = [
            {
                "role": "system",
                "content": "You are a Manim expert. Your task is to fix Python code that failed to render. You will be given the code and the error message. Return ONLY the fixed Python code. Do not wrap in markdown blocks if possible, or use ```python blocks."
            },
            {
                "role": "user",
                "content": f"The following Manim code failed with an error:\n\nERROR:\n{error_message}\n\nCODE:\n{code}\n\nFix the error and return the full corrected code."
            }
        ]

        try:
            response = litellm.completion(model=model, messages=messages)
            fixed_code = self._extract_code(response.choices[0].message.content)
            if not fixed_code.strip():
                # An empty reply would wipe out the scene; keep the original.
                return code
            # Apply post-processing to the fixed code as well
            return post_process_code(fixed_code)
        except Exception:
            # Best effort: never let the LLM fallback crash the fixer, but keep
            # a record of why it failed instead of swallowing it silently.
            logging.getLogger(__name__).warning(
                "LLM runtime-error fix failed; keeping original code", exc_info=True
            )
            return code

    @classmethod
    def _extract_code(cls, reply: Optional[str]) -> str:
        """Return the first fenced code block in ``reply`` (stripped), or
        ``reply`` itself when it has none; ``None``/empty yields ``''``."""
        if not reply:
            return ''
        match = cls._FENCE_RE.search(reply)
        return match.group(1).strip() if match else reply