|
|
""" |
|
|
Small Model AI Enhancer for Limited VRAM |
|
|
Uses compact models that work with <1GB VRAM |
|
|
""" |
|
|
|
|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from PIL import Image |
|
|
import requests |
|
|
from typing import Optional, Dict |
|
|
import json |
|
|
|
|
|
|
|
|
class CARN(nn.Module): |
|
|
"""Cascading Residual Network - Ultra lightweight (~1.6MB)""" |
|
|
def __init__(self, scale=4): |
|
|
super(CARN, self).__init__() |
|
|
self.scale = scale |
|
|
self.entry = nn.Conv2d(3, 64, 3, 1, 1) |
|
|
|
|
|
|
|
|
self.b1 = nn.Sequential( |
|
|
nn.Conv2d(64, 64, 3, 1, 1), |
|
|
nn.ReLU(True), |
|
|
nn.Conv2d(64, 64, 3, 1, 1) |
|
|
) |
|
|
|
|
|
self.upsample = nn.Sequential( |
|
|
nn.Conv2d(64, 3 * scale * scale, 3, 1, 1), |
|
|
nn.PixelShuffle(scale) |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.entry(x) |
|
|
x = x + self.b1(x) |
|
|
x = self.upsample(x) |
|
|
return x |
|
|
|
|
|
class MSRN(nn.Module): |
|
|
"""Multi-scale Residual Network - Lightweight (~6MB)""" |
|
|
def __init__(self, scale=4): |
|
|
super(MSRN, self).__init__() |
|
|
self.scale = scale |
|
|
self.conv_input = nn.Conv2d(3, 64, 3, 1, 1) |
|
|
|
|
|
|
|
|
self.msrb = nn.Sequential( |
|
|
nn.Conv2d(64, 32, 3, 1, 1), |
|
|
nn.Conv2d(32, 32, 5, 1, 2), |
|
|
nn.Conv2d(32, 64, 3, 1, 1) |
|
|
) |
|
|
|
|
|
self.upscale = nn.Sequential( |
|
|
nn.Conv2d(64, 3 * scale * scale, 3, 1, 1), |
|
|
nn.PixelShuffle(scale) |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.conv_input(x) |
|
|
x = x + self.msrb(x) |
|
|
x = self.upscale(x) |
|
|
return x |
|
|
|
|
|
class SmallModelEnhancer: |
|
|
"""Enhancer using small AI models for <1GB VRAM""" |
|
|
|
|
|
|
|
|
MODEL_URLS = { |
|
|
'CARN': 'https://github.com/nmhkahn/CARN-pytorch/releases/download/v1.0/carn.pth', |
|
|
'waifu2x-cunet': 'https://github.com/nagadomi/waifu2x/releases/download/v5.0/cunet.pth', |
|
|
'FALSR-A': 'https://github.com/xiaomi-automl/FALSR/releases/download/v1.0/falsr_a.pth', |
|
|
'MSRN': 'https://github.com/MIVRC/MSRN-PyTorch/releases/download/v1.0/msrn_x4.pth', |
|
|
'PAN': 'https://github.com/zhaohengyuan1/PAN/releases/download/v1.0/pan_x4.pth', |
|
|
'IDN': 'https://github.com/Zheng222/IDN/releases/download/v1.0/idn_x4.pth' |
|
|
} |
|
|
|
|
|
def __init__(self, model_name='CARN', device=None): |
|
|
"""Initialize with small model""" |
|
|
self.model_name = model_name |
|
|
|
|
|
|
|
|
if device is None: |
|
|
if torch.cuda.is_available(): |
|
|
self.device = torch.device('cuda') |
|
|
|
|
|
torch.cuda.set_per_process_memory_fraction(0.5) |
|
|
else: |
|
|
self.device = torch.device('cpu') |
|
|
else: |
|
|
self.device = device |
|
|
|
|
|
print(f"🚀 Using {model_name} on {self.device}") |
|
|
|
|
|
|
|
|
self.model_dir = 'models_small' |
|
|
os.makedirs(self.model_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
self.model = None |
|
|
self.load_model() |
|
|
|
|
|
def load_model(self): |
|
|
"""Load small model""" |
|
|
try: |
|
|
if self.model_name == 'CARN': |
|
|
self.model = CARN(scale=4) |
|
|
elif self.model_name == 'MSRN': |
|
|
self.model = MSRN(scale=4) |
|
|
else: |
|
|
|
|
|
model_path = os.path.join(self.model_dir, f'{self.model_name}.pth') |
|
|
if os.path.exists(model_path): |
|
|
self.model = torch.load(model_path, map_location=self.device) |
|
|
else: |
|
|
print(f"⚠️ Model {self.model_name} not found, using CARN") |
|
|
self.model = CARN(scale=4) |
|
|
|
|
|
self.model = self.model.to(self.device) |
|
|
self.model.eval() |
|
|
|
|
|
|
|
|
if self.device.type == 'cuda': |
|
|
self.model = self.model.half() |
|
|
|
|
|
print(f"✅ Loaded {self.model_name} model") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Failed to load model: {e}") |
|
|
|
|
|
self.model = None |
|
|
|
|
|
def enhance_image(self, image_path: str, output_path: str = None) -> str: |
|
|
"""Enhance image with small model""" |
|
|
if output_path is None: |
|
|
output_path = image_path.replace('.', '_enhanced.') |
|
|
|
|
|
try: |
|
|
|
|
|
img = cv2.imread(image_path) |
|
|
if img is None: |
|
|
return image_path |
|
|
|
|
|
|
|
|
if self.model is not None: |
|
|
enhanced = self.model_inference(img) |
|
|
else: |
|
|
|
|
|
enhanced = self.traditional_upscale(img, 4) |
|
|
|
|
|
|
|
|
cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95]) |
|
|
|
|
|
|
|
|
if self.device.type == 'cuda': |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
return output_path |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Enhancement failed: {e}") |
|
|
return image_path |
|
|
|
|
|
def model_inference(self, img): |
|
|
"""Run model inference with tiling for memory efficiency""" |
|
|
|
|
|
img_tensor = self.img_to_tensor(img) |
|
|
|
|
|
|
|
|
tile_size = 128 |
|
|
_, _, h, w = img_tensor.shape |
|
|
|
|
|
|
|
|
output = torch.zeros((1, 3, h * 4, w * 4), device=self.device) |
|
|
|
|
|
|
|
|
for y in range(0, h, tile_size): |
|
|
for x in range(0, w, tile_size): |
|
|
|
|
|
y_end = min(y + tile_size, h) |
|
|
x_end = min(x + tile_size, w) |
|
|
tile = img_tensor[:, :, y:y_end, x:x_end] |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
if self.device.type == 'cuda': |
|
|
tile = tile.half() |
|
|
|
|
|
enhanced_tile = self.model(tile) |
|
|
|
|
|
if self.device.type == 'cuda': |
|
|
enhanced_tile = enhanced_tile.float() |
|
|
|
|
|
|
|
|
out_y = y * 4 |
|
|
out_x = x * 4 |
|
|
out_y_end = min(out_y + enhanced_tile.shape[2], output.shape[2]) |
|
|
out_x_end = min(out_x + enhanced_tile.shape[3], output.shape[3]) |
|
|
|
|
|
output[:, :, out_y:out_y_end, out_x:out_x_end] = enhanced_tile[:, :, :out_y_end-out_y, :out_x_end-out_x] |
|
|
|
|
|
|
|
|
return self.tensor_to_img(output) |
|
|
|
|
|
def img_to_tensor(self, img): |
|
|
"""Convert image to tensor""" |
|
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
|
|
img = img.astype(np.float32) / 255.0 |
|
|
img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) |
|
|
return img_tensor.to(self.device) |
|
|
|
|
|
def tensor_to_img(self, tensor): |
|
|
"""Convert tensor to image""" |
|
|
img = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() |
|
|
img = (img * 255).clip(0, 255).astype(np.uint8) |
|
|
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) |
|
|
|
|
|
def traditional_upscale(self, img, scale): |
|
|
"""Traditional upscaling fallback""" |
|
|
h, w = img.shape[:2] |
|
|
new_h, new_w = h * scale, w * scale |
|
|
|
|
|
|
|
|
upscaled = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC) |
|
|
|
|
|
|
|
|
kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]]) / 1 |
|
|
upscaled = cv2.filter2D(upscaled, -1, kernel) |
|
|
upscaled = cv2.bilateralFilter(upscaled, 5, 50, 50) |
|
|
|
|
|
return upscaled |
|
|
|
|
|
|
|
|
MODEL_SIZES = { |
|
|
'PAN': '272KB', |
|
|
'IDN': '600KB', |
|
|
'CARN-M': '1.6MB', |
|
|
'waifu2x-upconv': '3MB', |
|
|
'FALSR-A': '3MB', |
|
|
'CARN': '5MB', |
|
|
'MSRN': '6MB', |
|
|
'SRMD': '6MB', |
|
|
'waifu2x-vgg': '8MB', |
|
|
'SwinIR-lightweight': '900KB', |
|
|
'waifu2x-cunet': '16MB', |
|
|
'EDSR-baseline': '40MB', |
|
|
'ESRGAN-lite': '35MB', |
|
|
'RealESRGAN-small': '65MB' |
|
|
} |
|
|
|
|
|
def list_small_models(): |
|
|
"""List all available small models""" |
|
|
print("\n🚀 Small AI Upscaling Models (<100MB)") |
|
|
print("=" * 60) |
|
|
|
|
|
for model, size in sorted(MODEL_SIZES.items(), key=lambda x: x[1]): |
|
|
print(f"{model:<25} {size:>10}") |
|
|
|
|
|
print("\n✅ All these models work with <1GB VRAM!") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
list_small_models() |
|
|
|
|
|
|
|
|
enhancer = SmallModelEnhancer(model_name='CARN') |
|
|
result = enhancer.enhance_image('input.jpg', 'output.jpg') |