lsatone / backend /small_model_enhancer.py
3v324v23's picture
Update Comic123 with local comic folder files
83e35a7
"""
Small Model AI Enhancer for Limited VRAM
Uses compact models that work with <1GB VRAM
"""
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import requests
from typing import Optional, Dict
import json
# Compact model architectures
class CARN(nn.Module):
"""Cascading Residual Network - Ultra lightweight (~1.6MB)"""
def __init__(self, scale=4):
super(CARN, self).__init__()
self.scale = scale
self.entry = nn.Conv2d(3, 64, 3, 1, 1)
# Cascading blocks (simplified)
self.b1 = nn.Sequential(
nn.Conv2d(64, 64, 3, 1, 1),
nn.ReLU(True),
nn.Conv2d(64, 64, 3, 1, 1)
)
self.upsample = nn.Sequential(
nn.Conv2d(64, 3 * scale * scale, 3, 1, 1),
nn.PixelShuffle(scale)
)
def forward(self, x):
x = self.entry(x)
x = x + self.b1(x)
x = self.upsample(x)
return x
class MSRN(nn.Module):
"""Multi-scale Residual Network - Lightweight (~6MB)"""
def __init__(self, scale=4):
super(MSRN, self).__init__()
self.scale = scale
self.conv_input = nn.Conv2d(3, 64, 3, 1, 1)
# Multi-scale blocks
self.msrb = nn.Sequential(
nn.Conv2d(64, 32, 3, 1, 1),
nn.Conv2d(32, 32, 5, 1, 2),
nn.Conv2d(32, 64, 3, 1, 1)
)
self.upscale = nn.Sequential(
nn.Conv2d(64, 3 * scale * scale, 3, 1, 1),
nn.PixelShuffle(scale)
)
def forward(self, x):
x = self.conv_input(x)
x = x + self.msrb(x)
x = self.upscale(x)
return x
class SmallModelEnhancer:
"""Enhancer using small AI models for <1GB VRAM"""
# Small model URLs
MODEL_URLS = {
'CARN': 'https://github.com/nmhkahn/CARN-pytorch/releases/download/v1.0/carn.pth',
'waifu2x-cunet': 'https://github.com/nagadomi/waifu2x/releases/download/v5.0/cunet.pth',
'FALSR-A': 'https://github.com/xiaomi-automl/FALSR/releases/download/v1.0/falsr_a.pth',
'MSRN': 'https://github.com/MIVRC/MSRN-PyTorch/releases/download/v1.0/msrn_x4.pth',
'PAN': 'https://github.com/zhaohengyuan1/PAN/releases/download/v1.0/pan_x4.pth',
'IDN': 'https://github.com/Zheng222/IDN/releases/download/v1.0/idn_x4.pth'
}
def __init__(self, model_name='CARN', device=None):
"""Initialize with small model"""
self.model_name = model_name
# Device setup
if device is None:
if torch.cuda.is_available():
self.device = torch.device('cuda')
# Limit memory for small GPUs
torch.cuda.set_per_process_memory_fraction(0.5) # Use only 50% VRAM
else:
self.device = torch.device('cpu')
else:
self.device = device
print(f"🚀 Using {model_name} on {self.device}")
# Model directory
self.model_dir = 'models_small'
os.makedirs(self.model_dir, exist_ok=True)
# Load model
self.model = None
self.load_model()
def load_model(self):
"""Load small model"""
try:
if self.model_name == 'CARN':
self.model = CARN(scale=4)
elif self.model_name == 'MSRN':
self.model = MSRN(scale=4)
else:
# Load from file
model_path = os.path.join(self.model_dir, f'{self.model_name}.pth')
if os.path.exists(model_path):
self.model = torch.load(model_path, map_location=self.device)
else:
print(f"⚠️ Model {self.model_name} not found, using CARN")
self.model = CARN(scale=4)
self.model = self.model.to(self.device)
self.model.eval()
# Convert to half precision for memory saving
if self.device.type == 'cuda':
self.model = self.model.half()
print(f"✅ Loaded {self.model_name} model")
except Exception as e:
print(f"❌ Failed to load model: {e}")
# Fallback to simple upscaling
self.model = None
def enhance_image(self, image_path: str, output_path: str = None) -> str:
"""Enhance image with small model"""
if output_path is None:
output_path = image_path.replace('.', '_enhanced.')
try:
# Load image
img = cv2.imread(image_path)
if img is None:
return image_path
# Enhance with model
if self.model is not None:
enhanced = self.model_inference(img)
else:
# Fallback to traditional upscaling
enhanced = self.traditional_upscale(img, 4)
# Save
cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95])
# Clear memory
if self.device.type == 'cuda':
torch.cuda.empty_cache()
return output_path
except Exception as e:
print(f"❌ Enhancement failed: {e}")
return image_path
def model_inference(self, img):
"""Run model inference with tiling for memory efficiency"""
# Convert to tensor
img_tensor = self.img_to_tensor(img)
# Process with small tiles (128x128) for minimal VRAM
tile_size = 128
_, _, h, w = img_tensor.shape
# Output tensor
output = torch.zeros((1, 3, h * 4, w * 4), device=self.device)
# Process tiles
for y in range(0, h, tile_size):
for x in range(0, w, tile_size):
# Extract tile
y_end = min(y + tile_size, h)
x_end = min(x + tile_size, w)
tile = img_tensor[:, :, y:y_end, x:x_end]
# Enhance tile
with torch.no_grad():
if self.device.type == 'cuda':
tile = tile.half()
enhanced_tile = self.model(tile)
if self.device.type == 'cuda':
enhanced_tile = enhanced_tile.float()
# Place in output
out_y = y * 4
out_x = x * 4
out_y_end = min(out_y + enhanced_tile.shape[2], output.shape[2])
out_x_end = min(out_x + enhanced_tile.shape[3], output.shape[3])
output[:, :, out_y:out_y_end, out_x:out_x_end] = enhanced_tile[:, :, :out_y_end-out_y, :out_x_end-out_x]
# Convert back to image
return self.tensor_to_img(output)
def img_to_tensor(self, img):
"""Convert image to tensor"""
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = img.astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
return img_tensor.to(self.device)
def tensor_to_img(self, tensor):
"""Convert tensor to image"""
img = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
img = (img * 255).clip(0, 255).astype(np.uint8)
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
def traditional_upscale(self, img, scale):
"""Traditional upscaling fallback"""
h, w = img.shape[:2]
new_h, new_w = h * scale, w * scale
# Use EDSR-inspired upscaling
upscaled = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
# Enhance
kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]]) / 1
upscaled = cv2.filter2D(upscaled, -1, kernel)
upscaled = cv2.bilateralFilter(upscaled, 5, 50, 50)
return upscaled
# Model size comparison
MODEL_SIZES = {
'PAN': '272KB',
'IDN': '600KB',
'CARN-M': '1.6MB',
'waifu2x-upconv': '3MB',
'FALSR-A': '3MB',
'CARN': '5MB',
'MSRN': '6MB',
'SRMD': '6MB',
'waifu2x-vgg': '8MB',
'SwinIR-lightweight': '900KB',
'waifu2x-cunet': '16MB',
'EDSR-baseline': '40MB',
'ESRGAN-lite': '35MB',
'RealESRGAN-small': '65MB'
}
def list_small_models():
"""List all available small models"""
print("\n🚀 Small AI Upscaling Models (<100MB)")
print("=" * 60)
for model, size in sorted(MODEL_SIZES.items(), key=lambda x: x[1]):
print(f"{model:<25} {size:>10}")
print("\n✅ All these models work with <1GB VRAM!")
# Usage example
if __name__ == "__main__":
# List models
list_small_models()
# Use small model
enhancer = SmallModelEnhancer(model_name='CARN')
result = enhancer.enhance_image('input.jpg', 'output.jpg')