# Example code for running inference on a pre-trained model
import os
import yaml
import numpy as np
import cv2
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

from model.net.super_retina import SuperRetina
from utils.image_process import *
from utils.common_util import *

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

# Detector-head modules whose checkpoint keys need a ".conv." segment inserted.
_DOUBLE_CONV_MODULES = ('dconv_up3', 'dconv_up2', 'dconv_up1')


def sigmoid(arr):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-arr))``."""
    return 1. / (1 + np.exp(-arr))


class Inference(object):
    """Run a SuperRetina model on an image pair and visualise the matches.

    The pipeline is: load/pre-process both images, run the network, extract
    keypoints and descriptors, kNN-match descriptors with a ratio test,
    estimate a homography, then build an overlay image and a match plot.
    """

    def __init__(self, model_path):
        """Read ``config.yaml`` from ``model_path`` and load the model.

        Args:
            model_path: Directory containing ``config.yaml`` and the
                ``SuperRetina_released.pth`` checkpoint.

        Raises:
            FileNotFoundError: If ``config.yaml`` is missing.
        """
        self.model_path = model_path
        config_path = os.path.join(model_path, 'config.yaml')
        if not os.path.exists(config_path):
            raise FileNotFoundError("Config File doesn't Exist")
        with open(config_path) as f:
            predict_config = yaml.safe_load(f)
        self.config = predict_config['PREDICT']

        self.nms_size = self.config['nms_size']
        self.nms_thresh = self.config['nms_thresh']
        self.scale = 8
        self.knn_thresh = self.config['knn_thresh']
        self.model_image_width = self.config['model_image_width']
        self.model_image_height = self.config['model_image_height']

        self.transformer = transforms.Compose([
            transforms.Resize((self.model_image_height, self.model_image_width)),
            transforms.ToTensor(),
        ])
        self.device = device
        self.model = self.load_model()
        self.knn_matcher = cv2.BFMatcher(cv2.NORM_L2)

    @staticmethod
    def _remap_double_conv_key(key):
        """Insert ``conv`` after a DoubleConv module name in a checkpoint key.

        ``dconv_upX.0.weight`` becomes ``dconv_upX.conv.0.weight`` (likewise
        for index ``2``). The module name is located wherever it appears in
        the dotted path, so prefixed keys are handled as well; any other key,
        including one that already has the ``conv`` segment, is returned
        unchanged.
        """
        parts = key.split('.')
        for idx, part in enumerate(parts[:-1]):
            if part in _DOUBLE_CONV_MODULES:
                # 0 and 2 are the convolution indices inside the Sequential.
                if parts[idx + 1] in ('0', '2'):
                    return '.'.join(parts[:idx + 1] + ['conv'] + parts[idx + 1:])
                break
        return key

    @staticmethod
    def _strip_prefix(state, prefix):
        """Return the entries of ``state`` under ``prefix`` with it removed."""
        return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}

    def load_model(self):
        """Build SuperRetina and load ``SuperRetina_released.pth`` into it.

        Checkpoint keys are adapted to the model layout, split into encoder,
        descriptor-head and detector-head state dicts, and loaded into the
        corresponding sub-modules.

        Returns:
            The model, moved to ``device`` and switched to eval mode.
        """
        model_save_path = os.path.join(self.model_path, 'SuperRetina_released.pth')
        print('Loading model from {}'.format(model_save_path))
        model = SuperRetina(config=self.config, encoder_device=device,
                            detector_device=device, descriptor_device=device)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(model_save_path, map_location=device)

        # Adapt checkpoint keys to the model structure.
        processed_weights = {
            self._remap_double_conv_key(k): v for k, v in checkpoint['net'].items()
        }

        prefixes = ('encoder.', 'desc_head.', 'det_head.')
        if all(k.startswith(prefixes) for k in processed_weights):
            # Keys already carry module prefixes: strip the leading prefix only.
            encoder_state = self._strip_prefix(processed_weights, 'encoder.')
            desc_head_state = self._strip_prefix(processed_weights, 'desc_head.')
            det_head_state = self._strip_prefix(processed_weights, 'det_head.')
        else:
            # No module prefixes: classify the keys by layer name.
            encoder_layers = ['conv1a', 'conv1b', 'conv2a', 'conv2b',
                              'conv3a', 'conv3b', 'conv4a', 'conv4b']
            desc_layers = ['convDa', 'convDb', 'convDc', 'trans_conv']
            det_layers = ['dconv_up3', 'dconv_up2', 'dconv_up1', 'conv_last']
            encoder_state = {k: v for k, v in processed_weights.items()
                             if any(layer in k for layer in encoder_layers)}
            desc_head_state = {k: v for k, v in processed_weights.items()
                               if any(layer in k for layer in desc_layers)}
            det_head_state = {k: v for k, v in processed_weights.items()
                              if any(layer in k for layer in det_layers)}

        # Load each state dict into its sub-module.
        model.encoder.load_state_dict(encoder_state)
        model.desc_head.load_state_dict(desc_head_state)
        model.det_head.load_state_dict(det_head_state)

        model.to(device)
        model.eval()
        return model

    def load_image(self, image):
        """Load (if given a path) and pre-process an image.

        Args:
            image: A file path, or an image array as returned by OpenCV.

        Returns:
            A ``uint8`` array resized to the configured model input size.

        Raises:
            FileNotFoundError: If ``image`` is a path that OpenCV cannot read.
        """
        if isinstance(image, str):
            print('Loading image from {}'.format(image))
            image_path = image
            image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            # cv2.imread signals failure by returning None instead of raising.
            if image is None:
                raise FileNotFoundError(
                    'Could not read image from {}'.format(image_path))

        # Force the image to the model input size defined in the config.
        # Note that cv2.resize takes the size as (width, height).
        if hasattr(self, 'model_image_width') and hasattr(self, 'model_image_height'):
            image = cv2.resize(image, (self.model_image_width, self.model_image_height))

        # Grayscale input is used as-is; for 3-channel input take channel 1
        # (the green channel).
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = image[:, :, 1]

        # NOTE(review): the *255 rescale presumes pre_processing (from
        # utils.image_process) returns values in [0, 1] -- confirm.
        image = pre_processing(image)
        image = (image * 255).astype(np.uint8)

        # Debug output to confirm the final size.
        print(f"Processed image shape: {image.shape}")
        return image

    def model_run_pair(self, query_image, refer_image):
        """Run the network on a query/reference pair.

        Args:
            query_image: Pre-processed query image array.
            refer_image: Pre-processed reference image array.

        Returns:
            ``(keypoints, descriptors, query_tensor, refer_tensor)`` where the
            first two are two-element lists of CPU tensors (query first) and
            the last two are the transformed network inputs as NumPy arrays.
        """
        query_tensor = self.transformer(Image.fromarray(query_image))
        refer_tensor = self.transformer(Image.fromarray(refer_image))

        inputs = torch.cat((query_tensor.unsqueeze(0), refer_tensor.unsqueeze(0)))
        inputs = inputs.to(self.device)

        with torch.no_grad():
            detector_pred, descriptor_pred = self.model(inputs)

        scores = simple_nms(detector_pred, self.nms_size)

        _, _, h, w = detector_pred.shape
        scores = scores.reshape(-1, h, w)

        keypoints = [torch.nonzero(s > self.nms_thresh) for s in scores]
        scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]

        # Discard keypoints near the image borders
        keypoints, scores = list(zip(*[
            remove_borders(k, s, 4, h, w)
            for k, s in zip(keypoints, scores)]))

        # nonzero yields (row, col); flip to (x, y).
        keypoints = [torch.flip(k, [1]).float() for k in keypoints]

        descriptors = [sample_keypoint_desc(k[None], d[None], self.scale)[0].cpu()
                       for k, d in zip(keypoints, descriptor_pred)]
        keypoints = [k.cpu() for k in keypoints]

        # Convert the input tensors to NumPy arrays for the caller.
        query_tensor = query_tensor.cpu().numpy()
        refer_tensor = refer_tensor.cpu().numpy()

        return keypoints, descriptors, query_tensor, refer_tensor

    def match(self, keypoints, descriptors):
        """kNN-match query and reference descriptors with a ratio test.

        Args:
            keypoints: Two-element sequence of (x, y) keypoint tensors.
            descriptors: Two-element sequence of descriptor tensors; each is
                transposed before matching so that rows are descriptors.

        Returns:
            ``(goodMatch, cv_kpts_query, cv_kpts_refer, status)``.
            ``goodMatch`` holds the matches that pass the ratio test and
            ``status`` has one bool per kNN pair (not per good match).
        """
        query_keypoints, refer_keypoints = keypoints[0], keypoints[1]
        # BFMatcher with NORM_L2 expects contiguous float32 rows.
        query_desc = np.ascontiguousarray(
            descriptors[0].permute(1, 0).numpy(), dtype=np.float32)
        refer_desc = np.ascontiguousarray(
            descriptors[1].permute(1, 0).numpy(), dtype=np.float32)

        # Wrap the keypoints as cv2.KeyPoint objects (integer-truncated coords).
        cv_kpts_query = [cv2.KeyPoint(int(i[0]), int(i[1]), 30)
                         for i in query_keypoints]
        cv_kpts_refer = [cv2.KeyPoint(int(i[0]), int(i[1]), 30)
                         for i in refer_keypoints]

        goodMatch = []
        status = []
        matches = []
        # The ratio test needs two neighbours, hence >= 2 reference descriptors.
        if len(query_desc) > 0 and len(refer_desc) >= 2:
            try:
                matches = self.knn_matcher.knnMatch(query_desc, refer_desc, k=2)
            except cv2.error as e:
                # Best effort: report the failure and continue with no matches.
                print(f"knnMatch failed: {e}")
                matches = []

        for pair in matches:
            if len(pair) < 2:
                status.append(False)
                continue
            m, n = pair
            if m.distance < self.knn_thresh * n.distance:
                goodMatch.append(m)
                status.append(True)
            else:
                status.append(False)

        print(len(goodMatch))
        print(len(cv_kpts_query))
        print(len(cv_kpts_refer))
        return goodMatch, cv_kpts_query, cv_kpts_refer, status

    def compute_homography(self, goodMatch, cv_kpts_query, cv_kpts_refer):
        """Estimate the query-to-reference homography from the good matches.

        Returns:
            ``(H_m, inliers_num_rate, goodMatch, mask)``. With fewer than four
            matches, or when OpenCV finds no homography, ``H_m`` and ``mask``
            are ``None``, the rate is ``0`` and ``goodMatch`` is returned
            unfiltered. Otherwise ``goodMatch`` is an object array holding
            only the inlier matches.
        """
        H_m = None
        inliers_num_rate = 0
        mask = None
        if len(goodMatch) >= 4:
            src_pts = [cv_kpts_query[m.queryIdx].pt for m in goodMatch]
            src_pts = np.float32(src_pts).reshape(-1, 1, 2)
            dst_pts = [cv_kpts_refer[m.trainIdx].pt for m in goodMatch]
            dst_pts = np.float32(dst_pts).reshape(-1, 1, 2)

            H_m, mask = cv2.findHomography(src_pts, dst_pts, cv2.LMEDS)
            if H_m is None or mask is None:
                # Estimation failed; report "no homography" consistently.
                H_m, mask = None, None
            else:
                inlier_flags = mask.ravel() == 1
                goodMatch = np.array(goodMatch)[inlier_flags]
                inliers_num_rate = mask.sum() / len(mask.ravel())

        print(H_m, inliers_num_rate)
        return H_m, inliers_num_rate, goodMatch, mask

    @staticmethod
    def _to_hwc_uint8(image):
        """Convert a (C, H, W) array to (H, W[, C]) ``uint8``.

        Arrays whose first axis is 1 or 3 are transposed to channels-last and
        a single channel is squeezed away. Non-``uint8`` data is multiplied
        by 255 before the cast.
        """
        if len(image.shape) == 3 and image.shape[0] in [1, 3]:
            image = image.transpose(1, 2, 0)
            if image.shape[2] == 1:
                image = image.squeeze(2)
        if image.dtype != np.uint8:
            image = (image * 255).astype(np.uint8)
        return image

    def align_image_pair(self, H_m, inliers_num_rate, query_image, refer_image):
        """Warp the query onto the reference and overlay both images.

        Args:
            H_m: Homography matrix, or ``None`` when matching failed.
            inliers_num_rate: Unused; kept for interface compatibility.
            query_image: Query image array, (C, H, W) or (H, W).
            refer_image: Reference image array, (C, H, W) or (H, W).

        Returns:
            An (H, W, 3) ``uint8`` image with the warped query in channel 0
            and the reference in channel 1, or ``None`` on failure.
        """
        if H_m is None:
            print("Matched Failed!")
            return None

        h, w = self.model_image_height, self.model_image_width

        # Make sure both images are channels-last uint8 before warping.
        query_image = self._to_hwc_uint8(query_image)
        refer_image = self._to_hwc_uint8(refer_image)

        print(f"query_image shape after conversion: {query_image.shape}")
        print(f"refer_image shape after conversion: {refer_image.shape}")

        try:
            query_align = cv2.warpPerspective(query_image, H_m, (w, h),
                                              borderMode=cv2.BORDER_CONSTANT,
                                              borderValue=(0))

            merged = np.zeros((h, w, 3), dtype=np.uint8)

            # Reduce the warped query image to a single channel.
            if len(query_align.shape) == 3:
                query_align_gray = cv2.cvtColor(query_align, cv2.COLOR_BGR2GRAY)
            else:
                query_align_gray = query_align

            # Reduce the reference image to a single channel.
            if len(refer_image.shape) == 3:
                refer_gray = cv2.cvtColor(refer_image, cv2.COLOR_BGR2GRAY)
            else:
                refer_gray = refer_image

            merged[:, :, 0] = query_align_gray
            merged[:, :, 1] = refer_gray
            return merged
        except Exception as e:
            # Deliberate best effort: report the problem and return None.
            print(f"图像对齐过程中出错: {e}")
            return None

    def draw_result(self, query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status):
        """Render keypoints and matches side by side.

        Args:
            query_image: Query image array (2-D or 3-channel).
            refer_image: Reference image array (2-D or 3-channel).
            cv_kpts_query: ``cv2.KeyPoint`` list for the query image.
            cv_kpts_refer: ``cv2.KeyPoint`` list for the reference image.
            matches: Matches to draw.
            status: One truthy/falsy flag per entry of ``matches``.

        Returns:
            The rendered Matplotlib figure as an (H, W, 3) ``uint8`` array.
        """
        def drawMatches(imageA, imageB, kpsA, kpsB, matches, status):
            # initialize the output visualization image
            (hA, wA) = imageA.shape[:2]
            (hB, wB) = imageB.shape[:2]
            print("query_image shape:", imageA.shape)
            print("refer_image shape:", imageB.shape)
            vis = np.zeros((max(hA, hB), wA + wB, 3), dtype="uint8")
            # Convert each grayscale image independently of the other.
            if len(imageA.shape) == 2:
                imageA = cv2.cvtColor(imageA, cv2.COLOR_GRAY2RGB)
            if len(imageB.shape) == 2:
                imageB = cv2.cvtColor(imageB, cv2.COLOR_GRAY2RGB)
            vis[0:hA, 0:wA] = imageA
            vis[0:hB, wA:] = imageB

            # loop over the matches
            for match, s in zip(matches, status):
                # only draw matches flagged as successful
                if not s:
                    continue
                trainIdx, queryIdx = match.trainIdx, match.queryIdx
                ptA = (int(kpsA[queryIdx].pt[0]), int(kpsA[queryIdx].pt[1]))
                ptB = (int(kpsB[trainIdx].pt[0]) + wA, int(kpsB[trainIdx].pt[1]))
                cv2.line(vis, ptA, ptB, (0, 255, 0), 2)

            # return the visualization
            return vis

        # reshape(-1, 2) keeps the column indexing valid with zero keypoints.
        query_np = np.array([kp.pt for kp in cv_kpts_query], dtype=float).reshape(-1, 2)
        refer_np = np.array([kp.pt for kp in cv_kpts_refer], dtype=float).reshape(-1, 2)
        # Reference keypoints are drawn in the right half of the canvas.
        refer_np[:, 0] += query_image.shape[1]

        matched_image = drawMatches(query_image, refer_image, cv_kpts_query,
                                    cv_kpts_refer, matches, status)

        fig = plt.figure(dpi=300)
        try:
            plt.scatter(query_np[:, 0], query_np[:, 1], s=1, c='r')
            plt.scatter(refer_np[:, 0], refer_np[:, 1], s=1, c='r')
            plt.axis('off')
            plt.title('Match Result, #goodMatch: {}'.format(int(np.sum(status))))
            plt.imshow(cv2.cvtColor(matched_image, cv2.COLOR_BGR2RGB))

            # Instead of showing the plot, return it as a numpy array.
            plt.tight_layout()
            fig.canvas.draw()
            # buffer_rgba() replaces the deprecated tostring_rgb(); drop alpha.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            buf = rgba[..., :3].copy()
        finally:
            # Always release the figure, even if rendering fails.
            plt.close(fig)

        return buf

    def inference(self, query_image, refer_image):
        """Run the full pipeline on an image pair.

        Args:
            query_image: Path or image array of the query image.
            refer_image: Path or image array of the reference image.

        Returns:
            ``(merged, match_show)``: the aligned overlay (``None`` when no
            homography was found) and the match visualisation. If any step
            raises, the error is printed and a black image carrying the
            error text is returned for both values.
        """
        try:
            query_processed = self.load_image(query_image)
            refer_processed = self.load_image(refer_image)

            keypoints, descriptors, query_tensor, refer_tensor = self.model_run_pair(
                query_processed, refer_processed)
            goodMatch, cv_kpts_query, cv_kpts_refer, status = self.match(keypoints, descriptors)
            H_m, inliers_num_rate, filtered_goodMatch, mask = self.compute_homography(
                goodMatch, cv_kpts_query, cv_kpts_refer)

            # filtered_goodMatch already contains only the homography inliers,
            # so each one is drawn. Without a homography nothing is drawn.
            if mask is not None:
                filtered_status = [True] * len(filtered_goodMatch)
            else:
                filtered_status = []

            merged = self.align_image_pair(H_m, inliers_num_rate, query_tensor, refer_tensor)
            match_show = self.draw_result(query_processed, refer_processed,
                                          cv_kpts_query, cv_kpts_refer,
                                          filtered_goodMatch, filtered_status)

            return merged, match_show
        except Exception as e:
            print(f"推理过程中出现错误: {e}")
            import traceback
            traceback.print_exc()

            # Return a placeholder image carrying the error message.
            h, w = self.model_image_height, self.model_image_width
            error_image = np.zeros((h, w, 3), dtype=np.uint8)
            cv2.putText(error_image, f"Error: {str(e)}", (10, h // 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
            return error_image, error_image