|
|
""" |
|
|
AI Model Manager for State-of-the-Art Image Enhancement |
|
|
Manages Real-ESRGAN, GFPGAN, SwinIR and other models |
|
|
Optimized for NVIDIA RTX 3050 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
import requests |
|
|
from tqdm import tqdm |
|
|
import hashlib |
|
|
from typing import Optional, Dict, Any |
|
|
import warnings |
|
|
warnings.filterwarnings('ignore') |
|
|
|
|
|
|
|
|
# Registry of downloadable model weights.
# Maps a model name to a dict with two keys:
#   'url'  - direct download link of the weights file
#   'hash' - expected 64-hex-digit digest of that file
# The table below is (name, url, hash); insertion order is preserved.
MODEL_URLS = {
    name: {'url': url, 'hash': digest}
    for name, url, digest in (
        (
            'RealESRGAN_x4plus',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
            '4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1',
        ),
        (
            'RealESRGAN_x4plus_anime_6B',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
            'f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da',
        ),
        (
            'RealESRNet_x4plus',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth',
            '99ec365d4afad750833258a1a24f44ca3fefd45f1bb7f14e1d195f21934bb428',
        ),
        (
            'GFPGAN_v1.3',
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            'c953a88f2ba4e03fb985a7582126c2267b4c3db0e50def3448b844e88e8b8f5e',
        ),
        (
            'detection_Resnet50_Final',
            'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth',
            '6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d',
        ),
        (
            'parsing_parsenet',
            'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_parsenet.pth',
            '3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2',
        ),
    )
}
|
|
|
|
|
class AIModelManager:
    """Manage Real-ESRGAN and GFPGAN models for image enhancement.

    Weights are downloaded into ``model_dir`` on demand and the models are
    created lazily the first time an enhancement method needs them.

    Attributes set by ``__init__``:
        device: ``torch.device`` every model is created on.
        model_dir: directory that holds the downloaded ``.pth`` files.
        realesrgan: standard Real-ESRGAN upsampler, or ``None`` until loaded.
        realesrgan_anime: anime Real-ESRGAN upsampler, or ``None`` until loaded.
        gfpgan: GFPGAN face restorer, or ``None`` until loaded.
        face_enhancer: initialised to ``None``; not used inside this class.
        current_models: initialised to ``{}``; not used inside this class.
    """

    # Seconds allowed for connecting / for each read while downloading weights.
    _DOWNLOAD_TIMEOUT = 30
    # Bytes fetched per iteration while streaming a download.
    _DOWNLOAD_CHUNK = 8192

    def __init__(self, device=None, model_dir='models'):
        """Select the compute device and prepare the model directory.

        Args:
            device: ``None`` to auto-select (first CUDA GPU when available,
                otherwise CPU), or anything ``torch.device`` accepts
                (a ``torch.device``, ``'cpu'``, ``'cuda:0'`` ...).
            model_dir: directory for downloaded weights; created if missing.
        """
        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')
                print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")

                # Let cuDNN pick the fastest kernels and allow TF32 math.
                torch.backends.cudnn.benchmark = True
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

                # Cap this process at 80% of the GPU memory.
                torch.cuda.set_per_process_memory_fraction(0.8)
            else:
                self.device = torch.device('cpu')
                print("💻 Using CPU (GPU not available)")
        else:
            # Normalise strings such as 'cpu' so later `self.device.type`
            # lookups work no matter how the caller spelled the device.
            self.device = torch.device(device)

        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)

        self.realesrgan = None
        self.realesrgan_anime = None
        self.gfpgan = None
        self.face_enhancer = None

        self.current_models = {}

    def download_model(self, model_name: str) -> str:
        """Return the local path of ``model_name``, downloading it if needed.

        The file is streamed to ``<name>.pth.part`` and only renamed to its
        final name after the transfer finished and the digest matched, so an
        interrupted or corrupted download is never mistaken for a valid model.

        Args:
            model_name: a key of ``MODEL_URLS``.

        Returns:
            Path of the ``.pth`` file inside ``self.model_dir``.

        Raises:
            ValueError: ``model_name`` is not in ``MODEL_URLS``.
            requests.RequestException: the HTTP request failed.
            RuntimeError: the downloaded file does not match the expected hash.
        """
        if model_name not in MODEL_URLS:
            raise ValueError(f"Unknown model: {model_name}")

        model_info = MODEL_URLS[model_name]
        model_path = os.path.join(self.model_dir, f"{model_name}.pth")

        if os.path.exists(model_path):
            print(f"✅ Model {model_name} already exists")
            return model_path

        print(f"📥 Downloading {model_name}...")

        os.makedirs(self.model_dir, exist_ok=True)
        partial_path = f"{model_path}.part"
        # The registry digests are 64 hex characters; treated as SHA-256.
        digest = hashlib.sha256()

        try:
            with requests.get(model_info['url'], stream=True,
                              timeout=self._DOWNLOAD_TIMEOUT) as response:
                # Without this an HTML error page would be saved as weights.
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))

                with open(partial_path, 'wb') as f:
                    with tqdm(total=total_size, unit='iB', unit_scale=True) as pbar:
                        for chunk in response.iter_content(chunk_size=self._DOWNLOAD_CHUNK):
                            if not chunk:
                                continue
                            f.write(chunk)
                            digest.update(chunk)
                            pbar.update(len(chunk))

            expected = model_info.get('hash')
            if expected and digest.hexdigest() != expected.lower():
                raise RuntimeError(
                    f"Checksum mismatch for {model_name}: "
                    f"expected {expected}, got {digest.hexdigest()}"
                )

            os.replace(partial_path, model_path)
        finally:
            # After a successful rename the partial file no longer exists;
            # on any failure this removes the incomplete download.
            if os.path.exists(partial_path):
                os.remove(partial_path)

        print(f"✅ Downloaded {model_name}")
        return model_path

    def load_realesrgan(self, model_name='RealESRGAN_x4plus', scale=4):
        """Create a Real-ESRGAN upsampler for ``model_name``.

        A model whose name contains ``'anime'`` is built with 6 RRDB blocks
        and stored in ``self.realesrgan_anime``; every other model is built
        with 23 blocks and stored in ``self.realesrgan``. The two slots are
        independent, so loading one variant never replaces the other.

        Args:
            model_name: key of ``MODEL_URLS`` naming the weights to load.
            scale: ``scale`` argument forwarded to ``RealESRGANer``.

        Returns:
            ``True`` on success, ``False`` if anything failed (the error is
            printed, not raised).
        """
        try:
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer

            print(f"🔄 Loading {model_name}...")

            model_path = self.download_model(model_name)

            is_anime = 'anime' in model_name
            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=6 if is_anime else 23,
            )

            upsampler = RealESRGANer(
                scale=scale,
                model_path=model_path,
                model=model,
                device=self.device,
                tile=256,
                tile_pad=10,
                pre_pad=0,
                # Half precision only on CUDA devices.
                half=self.device.type == 'cuda',
            )

            if is_anime:
                self.realesrgan_anime = upsampler
            else:
                self.realesrgan = upsampler

            print(f"✅ Loaded {model_name} on {self.device}")
            return True

        except Exception as e:
            print(f"❌ Failed to load Real-ESRGAN: {e}")
            return False

    def load_gfpgan(self):
        """Create the GFPGAN v1.3 face restorer in ``self.gfpgan``.

        ``self.realesrgan`` (possibly ``None``) is passed as the background
        upsampler, so load the standard Real-ESRGAN model first if it should
        be used for the background.

        Returns:
            ``True`` on success, ``False`` if anything failed (the error is
            printed, not raised).
        """
        try:
            from gfpgan import GFPGANer

            print("🔄 Loading GFPGAN v1.3...")

            model_path = self.download_model('GFPGAN_v1.3')
            # Fetched for the download side effect only: these paths are not
            # passed to GFPGANer below.
            self.download_model('detection_Resnet50_Final')
            self.download_model('parsing_parsenet')

            self.gfpgan = GFPGANer(
                model_path=model_path,
                upscale=2,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=self.realesrgan,
                device=self.device
            )

            print("✅ Loaded GFPGAN on", self.device)
            return True

        except Exception as e:
            print(f"❌ Failed to load GFPGAN: {e}")
            return False

    @staticmethod
    def _to_bgr(image, input_is_bgr=False):
        """Return ``image`` as a 3-channel BGR numpy array.

        PIL images are converted to arrays and always treated as RGB(A).
        For numpy input, ``input_is_bgr`` says whether colour channels are
        already in BGR(A) order; grayscale (2-D) input is expanded to BGR.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
            input_is_bgr = False

        if len(image.shape) == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

        channels = image.shape[2]
        if channels == 4:
            code = cv2.COLOR_BGRA2BGR if input_is_bgr else cv2.COLOR_RGBA2BGR
            return cv2.cvtColor(image, code)
        if channels == 3 and not input_is_bgr:
            return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        return image

    def enhance_image_realesrgan(self, image, use_anime_model=False, *,
                                 input_is_bgr=False):
        """Upscale ``image`` with Real-ESRGAN (4x, capped at 2048x1080).

        Args:
            image: PIL image or numpy array (grayscale, 3 or 4 channels).
            use_anime_model: use the anime model instead of the standard one;
                the requested variant is loaded on first use.
            input_is_bgr: set to ``True`` when a numpy ``image`` is already in
                BGR(A) order (e.g. from ``cv2.imread``). The default keeps the
                historical behaviour of treating arrays as RGB(A).

        Returns:
            The enhanced image as a BGR numpy array. If the model cannot be
            loaded the input is returned unchanged; if enhancement fails the
            input is returned (converted to BGR when the conversion itself
            succeeded).
        """
        if use_anime_model:
            upsampler = self.realesrgan_anime
            model_name = 'RealESRGAN_x4plus_anime_6B'
        else:
            upsampler = self.realesrgan
            model_name = 'RealESRGAN_x4plus'

        if upsampler is None:
            if not self.load_realesrgan(model_name):
                return image
            upsampler = self.realesrgan_anime if use_anime_model else self.realesrgan

        try:
            image = self._to_bgr(image, input_is_bgr)

            with torch.no_grad():
                output, _ = upsampler.enhance(image, outscale=4)

            # Shrink (keeping aspect ratio) anything larger than 2048x1080.
            h, w = output.shape[:2]
            if w > 2048 or h > 1080:
                scale = min(2048/w, 1080/h)
                new_w = int(w * scale)
                new_h = int(h * scale)
                output = cv2.resize(output, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
                print(f" 📐 Resized from {w}x{h} to {new_w}x{new_h} (2K limit)")

            return output

        except Exception as e:
            print(f"❌ Real-ESRGAN enhancement failed: {e}")
            return image

    def enhance_face_gfpgan(self, image, only_center_face=False, paste_back=True, *,
                            input_is_bgr=False):
        """Restore faces in ``image`` with GFPGAN (loaded on first use).

        Args:
            image: PIL image or numpy array (grayscale, 3 or 4 channels).
            only_center_face: forwarded to ``GFPGANer.enhance``.
            paste_back: forwarded to ``GFPGANer.enhance``.
            input_is_bgr: set to ``True`` when a numpy ``image`` is already in
                BGR(A) order. The default treats arrays as RGB(A).

        Returns:
            The third value returned by ``GFPGANer.enhance``. If the model
            cannot be loaded the input is returned unchanged; if enhancement
            fails the input is returned (converted to BGR when the conversion
            itself succeeded).
        """
        if self.gfpgan is None:
            if not self.load_gfpgan():
                return image

        try:
            image = self._to_bgr(image, input_is_bgr)

            with torch.no_grad():
                _, _, output = self.gfpgan.enhance(
                    image,
                    has_aligned=False,
                    only_center_face=only_center_face,
                    paste_back=paste_back,
                    weight=0.5
                )

            return output

        except Exception as e:
            print(f"❌ GFPGAN enhancement failed: {e}")
            return image

    @staticmethod
    def _default_output_path(image_path):
        """Insert ``_enhanced`` before the file extension of ``image_path``."""
        root, ext = os.path.splitext(image_path)
        return f"{root}_enhanced{ext}"

    def enhance_image_pipeline(self, image_path: str, output_path: Optional[str] = None,
                               enhance_face=True, use_anime_model=False) -> str:
        """Run super-resolution, optional face restoration and post-processing.

        Args:
            image_path: file to enhance.
            output_path: destination file; defaults to ``image_path`` with
                ``_enhanced`` inserted before the extension.
            enhance_face: also run GFPGAN on the upscaled image.
            use_anime_model: use the anime Real-ESRGAN model.

        Returns:
            ``output_path`` on success, or ``image_path`` if the image could
            not be read, processed or written (the error is printed).
        """
        print(f"🎨 Enhancing {os.path.basename(image_path)}...")

        try:
            # cv2.imread yields BGR, so every stage below is told that the
            # data is already BGR and must not have its channels swapped.
            image = cv2.imread(image_path)
            if image is None:
                print(f"❌ Failed to load image: {image_path}")
                return image_path

            original_shape = image.shape[:2]

            print(" 📈 Applying super-resolution (max 2K)...")
            enhanced = self.enhance_image_realesrgan(
                image, use_anime_model, input_is_bgr=True
            )

            if enhance_face:
                # NOTE(review): GFPGANer is created with upscale=2, so this
                # stage may enlarge the image beyond the 2K cap applied
                # above - confirm whether a second cap is wanted.
                print(" 👤 Enhancing faces...")
                enhanced = self.enhance_face_gfpgan(enhanced, input_is_bgr=True)

            print(" ✨ Applying final enhancements...")
            enhanced = self.post_process(enhanced)

            if output_path is None:
                output_path = self._default_output_path(image_path)

            # imwrite reports failure through its return value, not an
            # exception, so it has to be checked explicitly.
            if not cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                raise IOError(f"could not write {output_path}")

            new_shape = enhanced.shape[:2]
            print(f" ✅ Enhanced: {original_shape} → {new_shape}")

            return output_path

        except Exception as e:
            print(f"❌ Enhancement pipeline failed: {e}")
            return image_path

    @staticmethod
    def _scale_chroma(channel, gain):
        """Scale an 8-bit Lab chroma channel around its neutral value 128."""
        centred = channel.astype(np.float32) - 128.0
        return np.clip(np.rint(centred * gain + 128.0), 0, 255).astype(np.uint8)

    def post_process(self, image):
        """Sharpen, equalise lightness and mildly boost colour and brightness.

        Steps: 3x3 sharpening kernel, CLAHE on the Lab lightness channel,
        10% chroma gain on the a/b channels, then a 1.05 gain / +5 offset.

        Args:
            image: BGR image; the Lab handling assumes 8-bit data.

        Returns:
            The processed image, or ``image`` as it was when a step failed
            (the error is printed).
        """
        try:
            # Kernel weights sum to 1, so overall brightness is preserved.
            kernel = np.array([[-0.5, -0.5, -0.5],
                               [-0.5, 5, -0.5],
                               [-0.5, -0.5, -0.5]])
            image = cv2.filter2D(image, -1, kernel)

            lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
            l, a, b = cv2.split(lab)

            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            l = clahe.apply(l)

            # In 8-bit Lab the a/b channels are stored with a +128 offset, so
            # the gain is applied around 128; scaling the raw values would
            # tint neutral greys instead of increasing saturation.
            a = self._scale_chroma(a, 1.1)
            b = self._scale_chroma(b, 1.1)

            enhanced = cv2.merge([l, a, b])
            enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)

            enhanced = cv2.convertScaleAbs(enhanced, alpha=1.05, beta=5)

            return enhanced

        except Exception as e:
            print(f"⚠️ Post-processing failed: {e}")
            return image

    def clear_memory(self):
        """Empty the CUDA cache and synchronize; does nothing on other devices."""
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
|
|
|
|
|
|
|
|
# Module-wide manager instance; stays None until first requested.
_ai_model_manager = None


def get_ai_model_manager():
    """Return the shared AIModelManager, creating it on the first call."""
    global _ai_model_manager

    manager = _ai_model_manager
    if manager is not None:
        return manager

    # First call: build the manager with default arguments and remember it.
    manager = AIModelManager()
    _ai_model_manager = manager
    return manager