SuperRetinaDemo / inference.py
Hongyang-Li's picture
Update inference.py
e0677e4 verified
# Example code for running inference on a pre-trained model
import os
import yaml
import numpy as np
import cv2
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from model.net.super_retina import SuperRetina
from utils.image_process import *
from utils.common_util import *
# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")
def sigmoid(arr):
    """Apply the logistic function 1 / (1 + e^(-arr)) element-wise."""
    denominator = 1 + np.exp(-arr)
    return 1. / denominator
class Inference(object):
    """SuperRetina inference pipeline for a pair of retinal images.

    The model directory passed to the constructor must contain ``config.yaml``
    (whose ``PREDICT`` section configures inference) and the weight file
    ``SuperRetina_released.pth``. The pipeline detects keypoints, matches
    descriptors, estimates a homography and renders the results.
    """

    # DoubleConv sub-modules of the detector head. The checkpoint stores their
    # Sequential layers as "<name>.<idx>.<param>", while the model expects
    # "<name>.conv.<idx>.<param>".
    _DOUBLE_CONV_MODULES = ('dconv_up3', 'dconv_up2', 'dconv_up1')
    # Module prefixes a checkpoint may already carry, in
    # (encoder, desc_head, det_head) order.
    _MODULE_PREFIXES = ('encoder.', 'desc_head.', 'det_head.')
    # Layer-name fragments used to route un-prefixed checkpoint keys.
    _ENCODER_LAYERS = ('conv1a', 'conv1b', 'conv2a', 'conv2b',
                       'conv3a', 'conv3b', 'conv4a', 'conv4b')
    _DESC_HEAD_LAYERS = ('convDa', 'convDb', 'convDc', 'trans_conv')
    _DET_HEAD_LAYERS = ('dconv_up3', 'dconv_up2', 'dconv_up1', 'conv_last')

    def __init__(self, model_path):
        """Read the prediction config and load the pre-trained model.

        Args:
            model_path: Directory holding ``config.yaml`` and
                ``SuperRetina_released.pth``.

        Raises:
            FileNotFoundError: If ``config.yaml`` does not exist.
        """
        self.model_path = model_path
        config_path = os.path.join(model_path, 'config.yaml')
        if not os.path.exists(config_path):
            raise FileNotFoundError("Config File doesn't Exist")
        with open(config_path, encoding='utf-8') as f:
            predict_config = yaml.safe_load(f)
        self.config = predict_config['PREDICT']

        self.nms_size = self.config['nms_size']
        self.nms_thresh = self.config['nms_thresh']
        # Passed to sample_keypoint_desc in model_run_pair (presumably the
        # descriptor-map stride relative to the input image).
        self.scale = 8
        self.knn_thresh = self.config['knn_thresh']
        self.model_image_width = self.config['model_image_width']
        self.model_image_height = self.config['model_image_height']

        self.transformer = transforms.Compose([
            transforms.Resize((self.model_image_height, self.model_image_width)),
            transforms.ToTensor(),
        ])
        self.device = device
        self.model = self.load_model()
        self.knn_matcher = cv2.BFMatcher(cv2.NORM_L2)

    def load_model(self):
        """Build SuperRetina and load the released weights from ``model_path``.

        The checkpoint's ``'net'`` state dict is remapped to the current module
        layout. It is then split into encoder / descriptor-head / detector-head
        state dicts, and each is loaded into its sub-module.

        Returns:
            The model, moved to ``device`` and switched to eval mode.
        """
        model_save_path = os.path.join(self.model_path, 'SuperRetina_released.pth')
        print('Loading model from {}'.format(model_save_path))
        model = SuperRetina(config=self.config,
                            encoder_device=device,
                            detector_device=device,
                            descriptor_device=device)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_save_path, map_location=device)
        processed_weights = self._remap_double_conv_keys(checkpoint['net'])
        encoder_state, desc_head_state, det_head_state = self._split_state_dict(processed_weights)

        # Load each part into its own sub-module.
        model.encoder.load_state_dict(encoder_state)
        model.desc_head.load_state_dict(desc_head_state)
        model.det_head.load_state_dict(det_head_state)
        model.to(device)
        model.eval()
        return model

    @classmethod
    def _remap_double_conv_keys(cls, state_dict):
        """Rewrite ``dconv_upX.<idx>.<rest>`` keys as ``dconv_upX.conv.<idx>.<rest>``.

        A leading module prefix (e.g. ``det_head.``) and any trailing key
        segments are preserved. Keys that already contain ``.conv.`` after the
        DoubleConv name, and all other keys, are copied unchanged. Key order is
        kept.
        """
        remapped = {}
        for key, value in state_dict.items():
            parts = key.split('.')
            for i in range(len(parts) - 1):
                if parts[i] in cls._DOUBLE_CONV_MODULES and parts[i + 1].isdigit():
                    parts.insert(i + 1, 'conv')
                    break
            remapped['.'.join(parts)] = value
        return remapped

    @classmethod
    def _split_state_dict(cls, weights):
        """Split a flat state dict into ``(encoder, desc_head, det_head)`` dicts.

        If every key starts with ``encoder.``, ``desc_head.`` or ``det_head.``,
        keys are routed by that prefix and only the leading prefix is stripped.
        Otherwise keys are routed by layer-name substrings. A key matching no
        known layer is dropped, so the subsequent strict ``load_state_dict``
        will report it as missing.
        """
        if all(k.startswith(cls._MODULE_PREFIXES) for k in weights):
            return tuple(
                {k[len(prefix):]: v for k, v in weights.items() if k.startswith(prefix)}
                for prefix in cls._MODULE_PREFIXES
            )

        def select(layers):
            return {k: v for k, v in weights.items() if any(layer in k for layer in layers)}

        return (select(cls._ENCODER_LAYERS),
                select(cls._DESC_HEAD_LAYERS),
                select(cls._DET_HEAD_LAYERS))

    def load_image(self, image):
        """Load (if given a path) and pre-process one image for the model.

        Args:
            image: A file path, or an already-decoded image array.

        Returns:
            The resized image after ``pre_processing``, scaled by 255 to uint8.
            For 3-channel input only the green channel is kept.

        Raises:
            FileNotFoundError: If ``image`` is a path OpenCV cannot read.
        """
        if isinstance(image, str):
            print('Loading image from {}'.format(image))
            path = image
            image = cv2.imread(path, cv2.IMREAD_COLOR)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f"Cannot read image: {path}")
        # Force the image to the model input size from the config;
        # note cv2.resize takes (width, height).
        if hasattr(self, 'model_image_width') and hasattr(self, 'model_image_height'):
            image = cv2.resize(image, (self.model_image_width, self.model_image_height))
        # Grayscale input is used as-is. For 3-channel input keep the green
        # channel (index 1 is green in both BGR and RGB order).
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = image[:, :, 1]
        image = pre_processing(image)
        # Assumes pre_processing returns values in [0, 1] — TODO confirm.
        image = (image * 255).astype(np.uint8)
        print(f"Processed image shape: {image.shape}")
        return image

    def model_run_pair(self, query_image, refer_image):
        """Run the network on both images and extract keypoints and descriptors.

        Args:
            query_image, refer_image: Arrays accepted by ``PIL.Image.fromarray``
                (as produced by ``load_image``).

        Returns:
            keypoints: Two CPU float tensors of shape (N_i, 2) holding (x, y)
                coordinates on the detector-map grid.
            descriptors: Two CPU tensors sampled at those keypoints.
            query_tensor, refer_tensor: The network inputs (``ToTensor``
                output) converted to NumPy arrays.
        """
        query_tensor = self.transformer(Image.fromarray(query_image))
        refer_tensor = self.transformer(Image.fromarray(refer_image))
        inputs = torch.stack((query_tensor, refer_tensor)).to(self.device)

        with torch.no_grad():
            detector_pred, descriptor_pred = self.model(inputs)
            scores = simple_nms(detector_pred, self.nms_size)
            _, _, h, w = detector_pred.shape
            scores = scores.reshape(-1, h, w)
            # torch.nonzero yields (row, col) indices of the scores above threshold.
            keypoints = [torch.nonzero(s > self.nms_thresh) for s in scores]
            scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
            # Discard keypoints near the image borders.
            keypoints, scores = list(zip(*[
                remove_borders(k, s, 4, h, w)
                for k, s in zip(keypoints, scores)]))
            # Flip (row, col) -> (x, y).
            keypoints = [torch.flip(k, [1]).float() for k in keypoints]
            # descriptor_pred is iterated along its first (batch) axis, one entry per image.
            descriptors = [sample_keypoint_desc(k[None], d[None], self.scale)[0].cpu()
                           for k, d in zip(keypoints, descriptor_pred)]
        keypoints = [k.cpu() for k in keypoints]
        return keypoints, descriptors, query_tensor.cpu().numpy(), refer_tensor.cpu().numpy()

    def match(self, keypoints, descriptors):
        """Match query to reference descriptors with a k=2 ratio test.

        Returns:
            goodMatch: ``cv2.DMatch`` objects that pass
                ``best.distance < knn_thresh * second.distance``.
            cv_kpts_query, cv_kpts_refer: Keypoints as ``cv2.KeyPoint`` (size 30).
            status: One bool per k-NN result (True where the ratio test passed).
                It is indexed by k-NN result, NOT aligned with ``goodMatch``.
        """
        query_keypoints, refer_keypoints = keypoints[0], keypoints[1]
        # Descriptors are transposed so each row is one keypoint, as knnMatch expects.
        query_desc = descriptors[0].permute(1, 0).numpy()
        refer_desc = descriptors[1].permute(1, 0).numpy()
        cv_kpts_query = [cv2.KeyPoint(int(p[0]), int(p[1]), 30) for p in query_keypoints]
        cv_kpts_refer = [cv2.KeyPoint(int(p[0]), int(p[1]), 30) for p in refer_keypoints]

        goodMatch = []
        status = []
        # knnMatch raises on an empty descriptor set, so only match when both have entries.
        if len(query_desc) > 0 and len(refer_desc) > 0:
            try:
                matches = self.knn_matcher.knnMatch(query_desc, refer_desc, k=2)
            except cv2.error as e:
                print(f"knnMatch failed: {e}")
                matches = []
            for pair in matches:
                # knnMatch can return fewer than k neighbours (e.g. a single
                # reference descriptor); the ratio test needs two.
                if len(pair) < 2:
                    status.append(False)
                    continue
                m, n = pair
                passed = m.distance < self.knn_thresh * n.distance
                if passed:
                    goodMatch.append(m)
                status.append(passed)

        print(len(goodMatch))
        print(len(cv_kpts_query))
        print(len(cv_kpts_refer))
        return goodMatch, cv_kpts_query, cv_kpts_refer, status

    def compute_homography(self, goodMatch, cv_kpts_query, cv_kpts_refer):
        """Estimate the query->reference homography from matches with LMedS.

        Returns:
            H_m: The homography from ``cv2.findHomography``, or None when fewer
                than 4 matches are given (or OpenCV finds no model).
            inliers_num_rate: Fraction of matches flagged as inliers (0 if not computed).
            goodMatch: The inlier matches when a mask was produced; otherwise
                the input unchanged.
            mask: The inlier mask from ``cv2.findHomography``, or None.
        """
        H_m = None
        inliers_num_rate = 0
        mask = None
        if len(goodMatch) >= 4:
            src_pts = np.float32([cv_kpts_query[m.queryIdx].pt for m in goodMatch]).reshape(-1, 1, 2)
            dst_pts = np.float32([cv_kpts_refer[m.trainIdx].pt for m in goodMatch]).reshape(-1, 1, 2)
            H_m, mask = cv2.findHomography(src_pts, dst_pts, cv2.LMEDS)
            # Guard against a missing mask instead of crashing on mask.ravel().
            if mask is not None:
                inliers = mask.ravel() == 1
                goodMatch = np.array(goodMatch)[inliers]
                inliers_num_rate = mask.sum() / inliers.size
        print(H_m, inliers_num_rate)
        return H_m, inliers_num_rate, goodMatch, mask

    @staticmethod
    def _to_hw_uint8(image):
        """Convert a channel-first array to the layout and dtype OpenCV expects.

        A 3-D array whose first axis has size 1 or 3 is transposed to
        (H, W, C), and a single channel is squeezed away. Any non-uint8 array
        is then assumed to lie in [0, 1] and is scaled by 255 to uint8.
        """
        if len(image.shape) == 3 and image.shape[0] in (1, 3):
            image = image.transpose(1, 2, 0)
            if image.shape[2] == 1:
                image = image.squeeze(2)
        if image.dtype != np.uint8:
            image = (image * 255).astype(np.uint8)
        return image

    def align_image_pair(self, H_m, inliers_num_rate, query_image, refer_image):
        """Warp the query onto the reference and fuse both into one image.

        Args:
            H_m: Query->reference homography, or None if matching failed.
            inliers_num_rate: Unused; kept for interface compatibility.
            query_image, refer_image: (C, H, W) or (H, W[, C]) arrays; non-uint8
                arrays are assumed to be in [0, 1].

        Returns:
            An (H, W, 3) uint8 image with the warped query (grayscale) in
            channel 0, the reference in channel 1 and zeros in channel 2.
            Returns None if ``H_m`` is None or warping fails.
        """
        if H_m is None:
            print("Matched Failed!")
            return None
        h, w = self.model_image_height, self.model_image_width
        query_image = self._to_hw_uint8(query_image)
        refer_image = self._to_hw_uint8(refer_image)
        print(f"query_image shape after conversion: {query_image.shape}")
        print(f"refer_image shape after conversion: {refer_image.shape}")
        try:
            query_align = cv2.warpPerspective(query_image, H_m, (w, h),
                                              borderMode=cv2.BORDER_CONSTANT, borderValue=0)
            merged = np.zeros((h, w, 3), dtype=np.uint8)
            if len(query_align.shape) == 3:
                query_align_gray = cv2.cvtColor(query_align, cv2.COLOR_BGR2GRAY)
            else:
                query_align_gray = query_align
            if len(refer_image.shape) == 3:
                refer_gray = cv2.cvtColor(refer_image, cv2.COLOR_BGR2GRAY)
            else:
                refer_gray = refer_image
            merged[:, :, 0] = query_align_gray
            merged[:, :, 1] = refer_gray
            return merged
        except Exception as e:
            # Best-effort visualisation: report the failure and return no image.
            print(f"图像对齐过程中出错: {e}")
            return None

    def draw_result(self, query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status):
        """Render both images side by side with their keypoints and accepted matches.

        Args:
            query_image, refer_image: 2-D (grayscale) or 3-channel images.
            cv_kpts_query, cv_kpts_refer: ``cv2.KeyPoint`` lists for each image.
            matches: ``cv2.DMatch`` objects (``queryIdx`` indexes
                ``cv_kpts_query``, ``trainIdx`` indexes ``cv_kpts_refer``).
            status: Per-match flags paired with ``matches`` via zip; only
                truthy entries are drawn.

        Returns:
            The rendered Matplotlib figure as an (H, W, 3) uint8 array.
        """
        def drawMatches(imageA, imageB, kpsA, kpsB, matches, status):
            # Initialize the output visualization image.
            (hA, wA) = imageA.shape[:2]
            (hB, wB) = imageB.shape[:2]
            print("query_image shape:", imageA.shape)
            print("refer_image shape:", imageB.shape)
            vis = np.zeros((max(hA, hB), wA + wB, 3), dtype="uint8")
            # Promote each grayscale image independently so mixed inputs also work.
            if len(imageA.shape) == 2:
                imageA = cv2.cvtColor(imageA, cv2.COLOR_GRAY2RGB)
            if len(imageB.shape) == 2:
                imageB = cv2.cvtColor(imageB, cv2.COLOR_GRAY2RGB)
            vis[0:hA, 0:wA] = imageA
            vis[0:hB, wA:] = imageB
            for match, s in zip(matches, status):
                if not s:
                    continue
                ptA = (int(kpsA[match.queryIdx].pt[0]), int(kpsA[match.queryIdx].pt[1]))
                ptB = (int(kpsB[match.trainIdx].pt[0]) + wA, int(kpsB[match.trainIdx].pt[1]))
                cv2.line(vis, ptA, ptB, (0, 255, 0), 2)
            return vis

        # reshape(-1, 2) keeps the [:, 0] indexing valid when there are no keypoints.
        query_np = np.array([kp.pt for kp in cv_kpts_query], dtype=np.float64).reshape(-1, 2)
        refer_np = np.array([kp.pt for kp in cv_kpts_refer], dtype=np.float64).reshape(-1, 2)
        # The reference image is drawn to the right of the query image.
        refer_np[:, 0] += query_image.shape[1]
        matched_image = drawMatches(query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status)

        fig = plt.figure(dpi=300)
        plt.scatter(query_np[:, 0], query_np[:, 1], s=1, c='r')
        plt.scatter(refer_np[:, 0], refer_np[:, 1], s=1, c='r')
        plt.axis('off')
        plt.title('Match Result, #goodMatch: {}'.format(int(np.count_nonzero(status))))
        plt.imshow(cv2.cvtColor(matched_image, cv2.COLOR_BGR2RGB))
        plt.tight_layout()
        # Render off-screen and return the pixels instead of showing the plot.
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb(); its shape is the
        # real pixel size, so no reshape from get_width_height() is needed.
        buf = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
        plt.close(fig)
        return buf

    def inference(self, query_image, refer_image):
        """Register ``query_image`` to ``refer_image`` end to end.

        Args:
            query_image, refer_image: File paths or image arrays (see ``load_image``).

        Returns:
            ``(merged, match_show)``: the fused alignment image (None when no
            homography was found) and the rendered match visualisation. On any
            error, both are a black (model_image_height, model_image_width, 3)
            image with the error message drawn on it.
        """
        try:
            query_processed = self.load_image(query_image)
            refer_processed = self.load_image(refer_image)
            keypoints, descriptors, query_tensor, refer_tensor = self.model_run_pair(query_processed, refer_processed)
            goodMatch, cv_kpts_query, cv_kpts_refer, status = self.match(keypoints, descriptors)
            H_m, inliers_num_rate, filtered_goodMatch, mask = self.compute_homography(goodMatch, cv_kpts_query, cv_kpts_refer)
            # Draw only matches verified as homography inliers. compute_homography
            # has already reduced filtered_goodMatch to those inliers, so each one
            # is marked True. `status` is indexed by k-NN result, not by
            # goodMatch, so it must not be zipped with `mask`.
            if mask is not None:
                filtered_status = [True] * len(filtered_goodMatch)
            else:
                filtered_status = []
            merged = self.align_image_pair(H_m, inliers_num_rate, query_tensor, refer_tensor)
            match_show = self.draw_result(query_processed, refer_processed, cv_kpts_query, cv_kpts_refer,
                                          filtered_goodMatch, filtered_status)
            return merged, match_show
        except Exception as e:
            print(f"推理过程中出现错误: {e}")
            import traceback
            traceback.print_exc()
            # Return a placeholder image that shows the error message.
            h, w = self.model_image_height, self.model_image_width
            error_image = np.zeros((h, w, 3), dtype=np.uint8)
            cv2.putText(error_image, f"Error: {str(e)}", (10, h // 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
            return error_image, error_image