Spaces:
Running
Running
File size: 8,489 Bytes
6fc3143 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
"""
Code Fixer - Auto-fixes common issues in generated code
"""
import re
from typing import List, Tuple
from .code_postprocessor import post_process_code, fix_undefined_colors
from .code_validator import CodeValidator
class CodeFixer:
"""Auto-fixes common code issues"""
def __init__(self):
self.validator = CodeValidator()
def auto_fix(self, code: str, errors: List[str]) -> str:
"""
Attempt to auto-fix code based on errors.
Args:
code: Code to fix
errors: List of error messages
Returns:
Fixed code
"""
fixed_code = code
# Apply existing post-processor fixes
fixed_code = post_process_code(fixed_code)
# Fix missing imports
fixed_code = self._fix_missing_imports(fixed_code, errors)
# Fix undefined colors (already in post_process_code, but ensure it's applied)
fixed_code = fix_undefined_colors(fixed_code)
# Fix voiceover setup
fixed_code = self._fix_voiceover_setup(fixed_code, errors)
# Fix common syntax issues
fixed_code = self._fix_syntax_issues(fixed_code)
return fixed_code
def _fix_missing_imports(self, code: str, errors: List[str]) -> str:
"""Add missing imports"""
imports_to_add = []
if 'Missing required import: from manim import *' in str(errors):
if 'from manim import *' not in code:
imports_to_add.append('from manim import *')
if 'Missing required import: from manimator.scene.voiceover_scene import VoiceoverScene' in str(errors):
if 'from manimator.scene.voiceover_scene import VoiceoverScene' not in code:
imports_to_add.append('from manimator.scene.voiceover_scene import VoiceoverScene')
if 'Missing required import: from manimator.services.voiceover import SimpleElevenLabsService' in str(errors):
if 'from manimator.services.voiceover import SimpleElevenLabsService' not in code:
imports_to_add.append('from manimator.services.voiceover import SimpleElevenLabsService')
if 'Missing required import: from pathlib import Path' in str(errors):
if 'from pathlib import Path' not in code:
imports_to_add.append('from pathlib import Path')
if imports_to_add:
# Find where to insert imports (after existing imports or at the top)
lines = code.split('\n')
insert_idx = 0
# Find last import line
for i, line in enumerate(lines):
if line.strip().startswith('import ') or line.strip().startswith('from '):
insert_idx = i + 1
# Insert new imports
for imp in imports_to_add:
if imp not in code:
lines.insert(insert_idx, imp)
insert_idx += 1
code = '\n'.join(lines)
return code
def _fix_voiceover_setup(self, code: str, errors: List[str]) -> str:
"""Fix voiceover service setup"""
if 'Voiceover service not initialized' in str(errors):
# Find construct method
construct_match = re.search(r'def construct\(self\):\s*\n', code)
if construct_match:
construct_pos = construct_match.end()
# Check if setup already exists
if 'set_speech_service' not in code[construct_pos:construct_pos+500]:
# Find first non-empty line after construct
lines = code.split('\n')
construct_line_idx = None
for i, line in enumerate(lines):
if 'def construct(self):' in line:
construct_line_idx = i
break
if construct_line_idx is not None:
# Find insertion point (after construct, before other code)
insert_idx = construct_line_idx + 1
while insert_idx < len(lines) and (not lines[insert_idx].strip() or lines[insert_idx].strip().startswith('#')):
insert_idx += 1
# Get indentation
if insert_idx < len(lines):
indent = len(lines[insert_idx]) - len(lines[insert_idx].lstrip())
else:
indent = 8
# Insert voiceover setup with SimpleElevenLabsService
setup_line = ' ' * indent + 'self.set_speech_service(SimpleElevenLabsService(voice_id="Rachel", cache_dir=Path("media/voiceover/elevenlabs")))'
lines.insert(insert_idx, setup_line)
code = '\n'.join(lines)
return code
def _fix_syntax_issues(self, code: str) -> str:
"""Fix common syntax issues"""
# Fix common indentation issues
# Fix missing colons after if/for/while
code = re.sub(r'(if|for|while|def|class)\s+[^:]+$', r'\1:', code, flags=re.MULTILINE)
# Fix double colons
code = re.sub(r'::', ':', code)
return code
def fix_and_validate(self, code: str, max_attempts: int = 3) -> Tuple[str, bool, List[str]]:
"""
Fix code and validate until valid or max attempts reached.
Args:
code: Code to fix
max_attempts: Maximum fix attempts
Returns:
Tuple of (fixed_code, is_valid, remaining_errors)
"""
current_code = code
for attempt in range(max_attempts):
is_valid, errors = self.validator.validate(current_code)
if is_valid:
return (current_code, True, [])
# Get fixable errors
fixable_errors = self.validator.get_fixable_errors(errors)
if not fixable_errors:
# No more fixable errors
return (current_code, False, errors)
# Attempt to fix
current_code = self.auto_fix(current_code, fixable_errors)
# Final validation
is_valid, final_errors = self.validator.validate(current_code)
return (current_code, is_valid, final_errors)
def fix_runtime_error(self, code: str, error_message: str) -> str:
"""
Ask LLM to fix code based on a runtime error message.
"""
import litellm
import os
model = os.getenv("CODE_GEN_MODEL", "gpt-4o")
messages = [
{
"role": "system",
"content": "You are a Manim expert. Your task is to fix Python code that failed to render. You will be given the code and the error message. Return ONLY the fixed Python code. Do not wrap in markdown blocks if possible, or use ```python blocks."
},
{
"role": "user",
"content": f"The following Manim code failed with an error:\n\nERROR:\n{error_message}\n\nCODE:\n{code}\n\nFix the error and return the full corrected code."
}
]
try:
response = litellm.completion(model=model, messages=messages)
fixed_code = response.choices[0].message.content
# Extract code if wrapped in markdown
if "```python" in fixed_code:
import re
match = re.search(r"```python\n(.*?)```", fixed_code, re.DOTALL)
if match:
fixed_code = match.group(1).strip()
elif "```" in fixed_code:
import re
match = re.search(r"```\n(.*?)```", fixed_code, re.DOTALL)
if match:
fixed_code = match.group(1).strip()
# Apply post-processing to the fixed code as well
return post_process_code(fixed_code)
except Exception as e:
# If LLM fix fails, return original code to avoid crashing the fixer
return code
|