# Example code for running inference on a pre-trained model
import os
import yaml
import numpy as np
import cv2
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from model.net.super_retina import SuperRetina
from utils.image_process import *
from utils.common_util import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def sigmoid(arr):
    return 1. / (1 + np.exp(-arr))

class Inference(object):
    def __init__(self, model_path):
        self.model_path = model_path
        config_path = os.path.join(model_path, 'config.yaml')
        if os.path.exists(config_path):
            with open(config_path) as f:
                predict_config = yaml.safe_load(f)
            self.config = predict_config['PREDICT']
        else:
            raise FileNotFoundError("Config file doesn't exist")

        self.nms_size = self.config['nms_size']
        self.nms_thresh = self.config['nms_thresh']
        self.scale = 8
        self.knn_thresh = self.config['knn_thresh']
        self.model_image_width = self.config['model_image_width']
        self.model_image_height = self.config['model_image_height']

        self.transformer = transforms.Compose([
            transforms.Resize((self.model_image_height, self.model_image_width)),
            transforms.ToTensor(),
        ])
        self.model = self.load_model()
        self.device = device
        self.knn_matcher = cv2.BFMatcher(cv2.NORM_L2)
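
    # The PREDICT section of config.yaml is expected to supply the keys read
    # above. An illustrative example (assumed values, not canonical ones):
    #
    #   PREDICT:
    #     nms_size: 10
    #     nms_thresh: 0.01
    #     knn_thresh: 0.9
    #     model_image_width: 768
    #     model_image_height: 768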
    def load_model(self):
        model_save_path = os.path.join(self.model_path, 'SuperRetina_released.pth')
        print('Loading model from {}'.format(model_save_path))
        model = SuperRetina(config=self.config,
                            encoder_device=device,
                            detector_device=device,
                            descriptor_device=device)
        checkpoint = torch.load(model_save_path, map_location=device)

        # Rewrite checkpoint keys so they match the model structure
        processed_weights = {}
        for k, v in checkpoint['net'].items():
            # Handle the DoubleConv module weights in the detector head
            if any(layer in k for layer in ['dconv_up3', 'dconv_up2', 'dconv_up1']):
                # DoubleConv wraps its layers in a `.conv` Sequential, so the
                # key needs an extra `.conv.` segment
                if '.0.' in k or '.2.' in k:  # 0 and 2 are the conv layer indices in the Sequential
                    # Turn "dconv_upX.Y.weight" into "dconv_upX.conv.Y.weight"
                    parts = k.split('.')
                    module_name = parts[0]  # dconv_up3, dconv_up2, or dconv_up1
                    layer_idx = parts[1]    # 0 or 2
                    weight_type = parts[2]  # weight or bias
                    new_key = f"{module_name}.conv.{layer_idx}.{weight_type}"
                    processed_weights[new_key] = v
                else:
                    processed_weights[k] = v
            else:
                processed_weights[k] = v

        # Inspect the layout of the checkpoint
        if all(k.startswith('encoder.') or k.startswith('desc_head.') or k.startswith('det_head.')
               for k in processed_weights.keys()):
            # The checkpoint already carries the module prefixes
            encoder_state = {k: v for k, v in processed_weights.items() if k.startswith('encoder.')}
            desc_head_state = {k: v for k, v in processed_weights.items() if k.startswith('desc_head.')}
            det_head_state = {k: v for k, v in processed_weights.items() if k.startswith('det_head.')}
            # Strip the prefixes so the keys match each submodule
            encoder_state = {k.replace('encoder.', ''): v for k, v in encoder_state.items()}
            desc_head_state = {k.replace('desc_head.', ''): v for k, v in desc_head_state.items()}
            det_head_state = {k.replace('det_head.', ''): v for k, v in det_head_state.items()}
        else:
            # No module prefixes; classify the weights by layer name
            encoder_state = {k: v for k, v in processed_weights.items() if any(
                layer in k for layer in ['conv1a', 'conv1b', 'conv2a', 'conv2b',
                                         'conv3a', 'conv3b', 'conv4a', 'conv4b'])}
            desc_head_state = {k: v for k, v in processed_weights.items() if any(
                layer in k for layer in ['convDa', 'convDb', 'convDc', 'trans_conv'])}
            det_head_state = {k: v for k, v in processed_weights.items() if any(
                layer in k for layer in ['dconv_up3', 'dconv_up2', 'dconv_up1', 'conv_last'])}

        # Load each state dict into the corresponding submodule on its device
        model.encoder.load_state_dict(encoder_state)
        model.desc_head.load_state_dict(desc_head_state)
        model.det_head.load_state_dict(det_head_state)
        model.to(device)
        model.eval()
        return model
    def load_image(self, image):
        # Load the image and preprocess it
        if isinstance(image, str):
            print('Loading image from {}'.format(image))
            image = cv2.imread(image, cv2.IMREAD_COLOR)

        # Force-resize the image to the model input size defined in the config
        # (note that cv2.resize takes (width, height))
        if hasattr(self, 'model_image_width') and hasattr(self, 'model_image_height'):
            image = cv2.resize(image, (self.model_image_width, self.model_image_height))

        # Grayscale images pass through unchanged; for RGB images keep only
        # the green channel, which has the best vessel contrast in fundus photos
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = image[:, :, 1]

        image = pre_processing(image)
        image = (image * 255).astype(np.uint8)
        print(f"Processed image shape: {image.shape}")
        return image
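
    # `pre_processing` comes from utils.image_process via the star import. It
    # is assumed here to apply CLAHE-style contrast enhancement and return a
    # float image in [0, 1]; a minimal sketch of that assumed behaviour
    # (illustrative only, not this repo's implementation):
    #
    #   def pre_processing(img):
    #       clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    #       return clahe.apply(img.astype(np.uint8)).astype(np.float32) / 255.0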
    def model_run_pair(self, query_image, refer_image):
        query_tensor = self.transformer(Image.fromarray(query_image))
        refer_tensor = self.transformer(Image.fromarray(refer_image))
        inputs = torch.cat((query_tensor.unsqueeze(0), refer_tensor.unsqueeze(0)))
        inputs = inputs.to(self.device)

        with torch.no_grad():
            detector_pred, descriptor_pred = self.model(inputs)

        scores = simple_nms(detector_pred, self.nms_size)
        b, _, h, w = detector_pred.shape
        scores = scores.reshape(-1, h, w)
        keypoints = [
            torch.nonzero(s > self.nms_thresh)
            for s in scores]
        scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]

        # Discard keypoints near the image borders
        keypoints, scores = list(zip(*[
            remove_borders(k, s, 4, h, w)
            for k, s in zip(keypoints, scores)]))
        # nonzero() yields (row, col); flip to (x, y) for OpenCV
        keypoints = [torch.flip(k, [1]).float().data for k in keypoints]
        descriptors = [sample_keypoint_desc(k[None], d[None], 8)[0].cpu()
                       for k, d in zip(keypoints, descriptor_pred)]
        keypoints = [k.cpu() for k in keypoints]

        # Convert the PyTorch tensors to NumPy arrays
        query_tensor = query_tensor.cpu().numpy()
        refer_tensor = refer_tensor.cpu().numpy()
        return keypoints, descriptors, query_tensor, refer_tensor
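
    # `simple_nms`, `remove_borders`, and `sample_keypoint_desc` come from
    # utils.common_util via the star import. A SuperPoint-style max-pooling
    # NMS is assumed for simple_nms, roughly (illustrative sketch, treating
    # nms_size as the suppression radius):
    #
    #   def simple_nms(scores, nms_radius):
    #       pooled = torch.nn.functional.max_pool2d(
    #           scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
    #       return torch.where(scores == pooled, scores, torch.zeros_like(scores))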
    def match(self, keypoints, descriptors):
        query_keypoints, refer_keypoints = keypoints[0], keypoints[1]
        query_desc = descriptors[0].permute(1, 0).numpy()
        refer_desc = descriptors[1].permute(1, 0).numpy()

        # Wrap the keypoints as cv2.KeyPoint objects for OpenCV
        cv_kpts_query = [cv2.KeyPoint(int(i[0]), int(i[1]), 30)
                         for i in query_keypoints]
        cv_kpts_refer = [cv2.KeyPoint(int(i[0]), int(i[1]), 30)
                         for i in refer_keypoints]

        goodMatch = []
        status = []
        matches = []
        try:
            matches = self.knn_matcher.knnMatch(query_desc, refer_desc, k=2)
            # Lowe's ratio test: keep a match only if it is clearly better
            # than the second-best candidate
            for m, n in matches:
                if m.distance < self.knn_thresh * n.distance:
                    goodMatch.append(m)
                    status.append(True)
                else:
                    status.append(False)
        except Exception:
            # knnMatch fails when either image yields fewer than two keypoints
            pass

        print(f'good matches: {len(goodMatch)}, query keypoints: {len(cv_kpts_query)}, '
              f'refer keypoints: {len(cv_kpts_refer)}')
        return goodMatch, cv_kpts_query, cv_kpts_refer, status
    def compute_homography(self, goodMatch, cv_kpts_query, cv_kpts_refer):
        H_m = None
        inliers_num_rate = 0
        mask = None
        if len(goodMatch) >= 4:
            src_pts = [cv_kpts_query[m.queryIdx].pt for m in goodMatch]
            src_pts = np.float32(src_pts).reshape(-1, 1, 2)
            dst_pts = [cv_kpts_refer[m.trainIdx].pt for m in goodMatch]
            dst_pts = np.float32(dst_pts).reshape(-1, 1, 2)
            H_m, mask = cv2.findHomography(src_pts, dst_pts, cv2.LMEDS)
            # Keep only the inlier matches flagged by the robust estimator
            goodMatch = np.array(goodMatch)[mask.ravel() == 1]
            inliers_num_rate = mask.sum() / len(mask.ravel())
        print(f'homography:\n{H_m}\ninlier rate: {inliers_num_rate}')
        return H_m, inliers_num_rate, goodMatch, mask
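
    # cv2.LMEDS needs a majority of inliers to succeed; under heavier outlier
    # contamination, RANSAC with an explicit reprojection threshold is the
    # usual alternative (not what this script uses), e.g.:
    #
    #   H_m, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)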
    def align_image_pair(self, H_m, inliers_num_rate, query_image, refer_image):
        if H_m is not None:
            h, w = self.model_image_height, self.model_image_width

            # Bring the images into (height, width[, channels]) layout:
            # tensors arrive as (channels, height, width)
            if len(query_image.shape) == 3 and query_image.shape[0] in [1, 3]:
                query_image = query_image.transpose(1, 2, 0)
                # Drop the channel axis for single-channel images
                if query_image.shape[2] == 1:
                    query_image = query_image.squeeze(2)
            if len(refer_image.shape) == 3 and refer_image.shape[0] in [1, 3]:
                refer_image = refer_image.transpose(1, 2, 0)
                if refer_image.shape[2] == 1:
                    refer_image = refer_image.squeeze(2)

            # Make sure both images are uint8
            if query_image.dtype != np.uint8:
                query_image = (query_image * 255).astype(np.uint8)
            if refer_image.dtype != np.uint8:
                refer_image = (refer_image * 255).astype(np.uint8)
            print(f"query_image shape after conversion: {query_image.shape}")
            print(f"refer_image shape after conversion: {refer_image.shape}")

            try:
                # Warp the query image into the reference frame, then build a
                # two-channel overlay: aligned query in channel 0, reference
                # in channel 1
                query_align = cv2.warpPerspective(query_image, H_m, (w, h),
                                                  borderMode=cv2.BORDER_CONSTANT, borderValue=(0))
                merged = np.zeros((h, w, 3), dtype=np.uint8)
                if len(query_align.shape) == 3:
                    query_align_gray = cv2.cvtColor(query_align, cv2.COLOR_BGR2GRAY)
                else:
                    query_align_gray = query_align
                if len(refer_image.shape) == 3:
                    refer_gray = cv2.cvtColor(refer_image, cv2.COLOR_BGR2GRAY)
                else:
                    refer_gray = refer_image
                merged[:, :, 0] = query_align_gray
                merged[:, :, 1] = refer_gray
                return merged
            except Exception as e:
                print(f"Error while aligning images: {e}")
                return None

        print("Matching failed!")
        return None
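
    # Reading the overlay: structures that registered well appear in both
    # channels of `merged`, while residual misalignment shows up as
    # single-channel fringes around vessels.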
    def draw_result(self, query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status):
        def drawMatches(imageA, imageB, kpsA, kpsB, matches, status):
            # Initialize the output visualization image
            (hA, wA) = imageA.shape[:2]
            (hB, wB) = imageB.shape[:2]
            print("query_image shape:", imageA.shape)
            print("refer_image shape:", imageB.shape)
            vis = np.zeros((max(hA, hB), wA + wB, 3), dtype="uint8")
            if len(imageA.shape) == 2:
                imageA = cv2.cvtColor(imageA, cv2.COLOR_GRAY2RGB)
                imageB = cv2.cvtColor(imageB, cv2.COLOR_GRAY2RGB)
            vis[0:hA, 0:wA] = imageA
            vis[0:hB, wA:] = imageB

            # Loop over the matches and draw a line for each successful one
            for match, s in zip(matches, status):
                trainIdx, queryIdx = match.trainIdx, match.queryIdx
                if s:
                    ptA = (int(kpsA[queryIdx].pt[0]), int(kpsA[queryIdx].pt[1]))
                    ptB = (int(kpsB[trainIdx].pt[0]) + wA, int(kpsB[trainIdx].pt[1]))
                    cv2.line(vis, ptA, ptB, (0, 255, 0), 2)
            # Return the visualization
            return vis

        query_np = np.array([kp.pt for kp in cv_kpts_query])
        refer_np = np.array([kp.pt for kp in cv_kpts_refer])
        # Shift the reference keypoints right by the query image width so they
        # land on the right half of the side-by-side visualization
        refer_np[:, 0] += query_image.shape[1]
        matched_image = drawMatches(query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status)

        plt.figure(dpi=300)
        plt.scatter(query_np[:, 0], query_np[:, 1], s=1, c='r')
        plt.scatter(refer_np[:, 0], refer_np[:, 1], s=1, c='r')
        plt.axis('off')
        plt.title('Match Result, #goodMatch: {}'.format(np.sum(status)))
        plt.imshow(cv2.cvtColor(matched_image, cv2.COLOR_BGR2RGB))

        # Instead of showing the plot, render the figure into a NumPy array
        plt.tight_layout()
        fig = plt.gcf()
        fig.canvas.draw()
        w, h = fig.canvas.get_width_height()
        # tostring_rgb() is removed in newer Matplotlib releases; on those
        # versions use np.asarray(fig.canvas.buffer_rgba())[:, :, :3] instead
        buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        buf.shape = (h, w, 3)
        plt.close()
        return buf
    def inference(self, query_image, refer_image):
        try:
            query_processed = self.load_image(query_image)
            refer_processed = self.load_image(refer_image)
            keypoints, descriptors, query_tensor, refer_tensor = self.model_run_pair(
                query_processed, refer_processed)
            goodMatch, cv_kpts_query, cv_kpts_refer, status = self.match(keypoints, descriptors)
            H_m, inliers_num_rate, filtered_goodMatch, mask = self.compute_homography(
                goodMatch, cv_kpts_query, cv_kpts_refer)

            # filtered_goodMatch already contains only the homography inliers,
            # so every surviving match is a successful one (the raw `status`
            # list is indexed per knn match and does not line up with `mask`)
            if mask is not None:
                filtered_status = [True] * len(filtered_goodMatch)
            else:
                filtered_status = []

            merged = self.align_image_pair(H_m, inliers_num_rate, query_tensor, refer_tensor)
            match_show = self.draw_result(query_processed, refer_processed,
                                          cv_kpts_query, cv_kpts_refer,
                                          filtered_goodMatch, filtered_status)
            return merged, match_show
        except Exception as e:
            print(f"Error during inference: {e}")
            import traceback
            traceback.print_exc()
            # Return a placeholder image carrying the error message
            h, w = self.model_image_height, self.model_image_width
            error_image = np.zeros((h, w, 3), dtype=np.uint8)
            cv2.putText(error_image, f"Error: {str(e)}", (10, h // 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
            return error_image, error_image
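

# A minimal usage sketch (assumed entry point; the model directory and the two
# image paths are placeholders, not files shipped with this snippet):
if __name__ == '__main__':
    # Directory expected to hold config.yaml and SuperRetina_released.pth
    infer = Inference(model_path='save')
    merged, match_show = infer.inference('query.jpg', 'refer.jpg')
    if merged is not None:
        cv2.imwrite('aligned_overlay.png', merged)
    # draw_result returns an RGB array; convert for OpenCV's BGR writer
    cv2.imwrite('matches.png', cv2.cvtColor(match_show, cv2.COLOR_RGB2BGR))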