# lsatone / backend / ultra_compact_enhancer.py
# 3v324v23's picture
# Update Comic123 with local comic folder files
# 83e35a7
"""
Ultra Compact Image Enhancer for Extreme Memory Constraints
Designed for RTX 3050 Laptop with strict <1GB VRAM limit
"""
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import gc
class UltraCompactESRGAN(nn.Module):
    """Small ESRGAN-style upscaler: three 3x3 convolutions and a pixel-shuffle head.

    Args:
        scale: Integer upsampling factor applied by the pixel-shuffle head.
    """

    def __init__(self, scale=2):
        super().__init__()
        self.scale = scale

        # Narrow feature width keeps both parameters and activations small.
        channels = 24
        same_size_3x3 = dict(kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(3, channels, **same_size_3x3)
        self.conv2 = nn.Conv2d(channels, channels, **same_size_3x3)
        self.conv3 = nn.Conv2d(channels, channels, **same_size_3x3)

        # The last convolution emits scale*scale sub-pixels per colour channel,
        # which PixelShuffle rearranges into the enlarged image.
        self.upscale = nn.Sequential(
            nn.Conv2d(channels, 3 * scale * scale, **same_size_3x3),
            nn.PixelShuffle(scale),
        )
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upscale an (N, 3, H, W) batch to (N, 3, H*scale, W*scale)."""
        shallow = self.act(self.conv1(x))
        deep = self.conv3(self.act(self.conv2(shallow)))
        # Residual link from the first feature map around the two inner convs.
        return self.upscale(shallow + deep)
class MemorySafeEnhancer:
    """Tile-based image upscaler that keeps GPU memory use small.

    Images are upscaled by ``scale`` with ``UltraCompactESRGAN``, one small
    tile at a time, and the result is capped to 2048x1080. When the model is
    unavailable, the input is larger than ``MAX_MODEL_INPUT`` on either side,
    or anything fails, a pure-OpenCV upscale is used instead.

    On CUDA the caching allocator is capped at ``VRAM_FRACTION`` of the
    device's memory; the absolute limit therefore depends on the card.

    NOTE(review): no checkpoint is loaded anywhere in this class, so the
    network runs with its freshly initialised weights - confirm whether
    trained weights are meant to be loaded.
    """

    # Share of total device memory the CUDA caching allocator may use.
    VRAM_FRACTION = 0.3
    # Inputs larger than this on either side bypass the model entirely.
    MAX_MODEL_INPUT = 2048
    # Output size cap (the "2K limit").
    MAX_OUTPUT_WIDTH = 2048
    MAX_OUTPUT_HEIGHT = 1080
    # Tiles thinner than this are interpolated instead of sent to the model.
    MIN_MODEL_TILE = 4

    def __init__(self):
        self.device = self._setup_device()
        self.model = None
        self.tile_size = 64  # Very small tiles
        self.scale = 2  # 2x max for 2K output
        self._load_model()

    def _setup_device(self):
        """Pick CUDA when available (with a capped memory fraction), else CPU."""
        if not torch.cuda.is_available():
            print("💻 Using CPU")
            return torch.device('cpu')

        # Start from a clean allocator before applying the cap.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        torch.cuda.set_per_process_memory_fraction(self.VRAM_FRACTION)
        print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")
        total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        print(f"📊 Total VRAM: {total:.1f}GB, Using max: {total * self.VRAM_FRACTION:.1f}GB")
        return torch.device('cuda')

    def _load_model(self):
        """Instantiate the network on ``self.device``; leave ``self.model`` None on failure."""
        try:
            print("🔄 Loading ultra-compact model...")
            model = UltraCompactESRGAN(scale=self.scale).to(self.device)
            model.eval()
            # Half precision halves weight and activation memory on the GPU.
            if self.device.type == 'cuda':
                model = model.half()
            self.model = model
            # Use each parameter's real element size so the figure is right
            # for FP32 on CPU as well as FP16 on GPU.
            model_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
            print(f"✅ Model loaded: {model_bytes / (1024**2):.1f}MB")
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            self.model = None

    @staticmethod
    def _default_output_path(image_path):
        """Return ``image_path`` with ``_enhanced`` inserted before its extension."""
        # splitext only touches the final extension, so dots in directory
        # names or earlier in the file name are left alone.
        root, ext = os.path.splitext(image_path)
        return f"{root}_enhanced{ext}"

    @classmethod
    def _capped_size(cls, width, height, max_scale):
        """Scale (width, height) by at most ``max_scale`` while fitting the output cap.

        Returns:
            ``(new_width, new_height)``, each at least 1 so the result is
            always a valid resize target.
        """
        factor = min(
            max_scale,
            cls.MAX_OUTPUT_WIDTH / width,
            cls.MAX_OUTPUT_HEIGHT / height,
        )
        return max(1, int(width * factor)), max(1, int(height * factor))

    def enhance_image(self, image_path: str, output_path: Optional[str] = None) -> str:
        """Upscale the image at ``image_path`` and write it to ``output_path``.

        Args:
            image_path: Path of the image to read.
            output_path: Destination path. Defaults to ``image_path`` with
                ``_enhanced`` inserted before the file extension.

        Returns:
            ``output_path`` when a result was written, otherwise the
            unchanged ``image_path``.
        """
        if output_path is None:
            output_path = self._default_output_path(image_path)
        print(f"🎨 Enhancing {os.path.basename(image_path)}...")

        img = None
        try:
            img = cv2.imread(image_path)
            if img is None:
                print("❌ Failed to read image")
                return image_path
            h, w = img.shape[:2]
            print(f" Input: {w}x{h}")

            if h > self.MAX_MODEL_INPUT or w > self.MAX_MODEL_INPUT:
                print(" ⚠️ Large image, using CPU fallback")
                enhanced = self._cpu_upscale(img)
            elif self.model is not None:
                enhanced = self._enhance_with_model(img)
            else:
                enhanced = self._cpu_upscale(img)

            # Ensure 2K limit
            h, w = enhanced.shape[:2]
            if w > self.MAX_OUTPUT_WIDTH or h > self.MAX_OUTPUT_HEIGHT:
                new_w, new_h = self._capped_size(w, h, 1.0)
                enhanced = cv2.resize(enhanced, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
                print(f" 📐 Resizing from {w}x{h} to {new_w}x{new_h} (2K limit)")

            # imwrite signals some failures (e.g. a missing directory) only
            # through its return value, so check it explicitly.
            if not cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                raise OSError(f"could not write {output_path}")
            new_h, new_w = enhanced.shape[:2]
            print(f" ✅ Output: {new_w}x{new_h}")

            self._cleanup_memory()
            return output_path
        except Exception as e:
            print(f" ❌ Enhancement failed: {e}")
            return self._write_cpu_fallback(img, image_path, output_path)

    def _write_cpu_fallback(self, img, image_path, output_path):
        """Last-resort path: CPU-upscale ``img`` and write it; never raises.

        Returns ``output_path`` on success and ``image_path`` otherwise.
        """
        if img is None:
            # The image never loaded, so there is nothing to fall back on.
            return image_path
        try:
            if cv2.imwrite(output_path, self._cpu_upscale(img)):
                return output_path
            print(f" ❌ CPU fallback failed: could not write {output_path}")
        except Exception as e:
            print(f" ❌ CPU fallback failed: {e}")
        return image_path

    def _enhance_with_model(self, img):
        """Upscale ``img`` tile by tile with the model and return the stitched result.

        Tiles are processed independently (no overlap). A tile that is too
        thin for the model, or whose inference fails, is bicubic-upscaled so
        every region of the output is filled.
        """
        h, w = img.shape[:2]
        scale = self.scale
        tile_size = self.tile_size
        # The stitched output lives in host memory, not on the GPU.
        output = np.zeros((h * scale, w * scale, 3), dtype=np.uint8)
        print(f" Processing {tile_size}x{tile_size} tiles...")

        failed_tiles = 0
        for y in range(0, h, tile_size):
            for x in range(0, w, tile_size):
                tile = img[y:min(y + tile_size, h), x:min(x + tile_size, w)]
                tile_h, tile_w = tile.shape[:2]

                enhanced_tile = None
                if tile_h >= self.MIN_MODEL_TILE and tile_w >= self.MIN_MODEL_TILE:
                    try:
                        enhanced_tile = self._process_single_tile(tile)
                    except Exception:
                        failed_tiles += 1
                if enhanced_tile is None:
                    enhanced_tile = cv2.resize(
                        tile,
                        (tile_w * scale, tile_h * scale),
                        interpolation=cv2.INTER_CUBIC,
                    )

                out_y = y * scale
                out_x = x * scale
                output[
                    out_y:out_y + enhanced_tile.shape[0],
                    out_x:out_x + enhanced_tile.shape[1],
                ] = enhanced_tile

                # Hand cached blocks back after every tile to keep the
                # reserved pool small.
                if self.device.type == 'cuda':
                    torch.cuda.empty_cache()

        if failed_tiles:
            print(f" ⚠️ {failed_tiles} tile(s) fell back to bicubic upscaling")
        return output

    def _process_single_tile(self, tile):
        """Run one BGR uint8 tile through the model and return the BGR uint8 result."""
        tile_rgb = cv2.cvtColor(tile, cv2.COLOR_BGR2RGB)
        tile_norm = tile_rgb.astype(np.float32) / 255.0
        # HWC -> NCHW with a batch dimension of one.
        tile_tensor = torch.from_numpy(tile_norm).permute(2, 0, 1).unsqueeze(0)
        tile_tensor = tile_tensor.to(self.device)
        # The model is held in FP16 on CUDA, so the input must match.
        if self.device.type == 'cuda':
            tile_tensor = tile_tensor.half()

        with torch.no_grad():
            enhanced_tensor = self.model(tile_tensor)

        # NCHW -> HWC, back to host float32, then to 8-bit BGR.
        enhanced = enhanced_tensor.squeeze(0).permute(1, 2, 0)
        enhanced = enhanced.cpu().float().numpy()
        enhanced = (enhanced * 255).clip(0, 255).astype(np.uint8)
        enhanced = cv2.cvtColor(enhanced, cv2.COLOR_RGB2BGR)

        # Drop the tensor references before returning.
        del tile_tensor, enhanced_tensor
        return enhanced

    def _cpu_upscale(self, img):
        """Upscale ``img`` with OpenCV only: blended interpolation plus sharpening.

        The scale factor is ``self.scale`` at most and is reduced (possibly
        below 1) so the result fits the output cap.
        """
        print(" 📈 Using CPU upscaling...")
        h, w = img.shape[:2]
        new_w, new_h = self._capped_size(w, h, self.scale)

        # Average two interpolation kernels.
        cubic = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        lanczos = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
        result = cv2.addWeighted(cubic, 0.5, lanczos, 0.5, 0)

        # Sharpening kernel; the weights sum to 1 so brightness is preserved.
        kernel = np.array([[0, -1, 0],
                           [-1, 5, -1],
                           [0, -1, 0]], dtype=np.float32)
        return cv2.filter2D(result, -1, kernel)

    def _cleanup_memory(self):
        """Run the garbage collector and, on CUDA, empty the allocator cache."""
        gc.collect()
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def get_memory_usage(self):
        """Return a human-readable summary of current CUDA memory use."""
        if self.device.type == 'cuda':
            allocated = torch.cuda.memory_allocated() / (1024**2)
            reserved = torch.cuda.memory_reserved() / (1024**2)
            return f"Allocated: {allocated:.1f}MB, Reserved: {reserved:.1f}MB"
        return "Using CPU"
# Lazily created enhancer shared by every caller in this process.
_memory_safe_enhancer = None


def get_memory_safe_enhancer():
    """Return the shared ``MemorySafeEnhancer``, constructing it on first use."""
    global _memory_safe_enhancer
    if _memory_safe_enhancer is not None:
        return _memory_safe_enhancer
    _memory_safe_enhancer = MemorySafeEnhancer()
    return _memory_safe_enhancer
# Module-level convenience entry point
def enhance_image_safe(image_path: str, output_path: str = None) -> str:
    """Enhance one image via the shared enhancer instance.

    Args:
        image_path: Path of the image to enhance.
        output_path: Optional destination path, forwarded unchanged.

    Returns:
        Whatever the shared enhancer's ``enhance_image`` returns.
    """
    return get_memory_safe_enhancer().enhance_image(image_path, output_path)