Hongyang Li committed: Upload 2 files
- utils/common_util.py +301 -0
- utils/image_process.py +403 -0
utils/common_util.py
ADDED
@@ -0,0 +1,301 @@
from torch.nn import functional as F
import torch
import matplotlib.pyplot as plt

from torchvision import transforms
import numpy as np
from PIL import Image
import cv2
import torch.nn as nn
import scipy.stats as st


# Remove keypoints (and their scores) that lie too close to the image border.
# Used only at inference time; training does not call this function.
def remove_borders(keypoints, scores, border: int, height: int, width: int):
    """ Removes keypoints too close to the border """
    '''
    keypoints: (N, 2) array of keypoint coordinates, one (y, x) pair per row.
    scores: score of each keypoint, shape (N,).
    border: width of the border strip to remove; the inference default is 4 px.
    height: image height.
    width: image width.
    '''
    # Height mask: keypoints must lie in [border, height - border)
    mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
    # Width mask: keypoints must lie in [border, width - border)
    mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
    # Combine the masks (both conditions must hold)
    mask = mask_h & mask_w
    # Return the filtered keypoints and scores
    return keypoints[mask], scores[mask]
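# Usage sketch (values are illustrative, not from the original repo): drop the
# points within 4 px of the border of a 100x100 image.
#   >>> kpts = torch.tensor([[2, 50], [50, 50], [99, 3]])  # (y, x) rows
#   >>> scr = torch.tensor([0.9, 0.8, 0.7])
#   >>> remove_borders(kpts, scr, border=4, height=100, width=100)
#   (tensor([[50, 50]]), tensor([0.8000]))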


def simple_nms(scores, nms_radius: int):
    """
    Fast non-maximum suppression (NMS) used to remove adjacent keypoints.

    Args:
        scores: keypoint score map, (B, H, W) or (H, W)
        nms_radius: NMS neighbourhood radius

    Returns:
        Score map after NMS.
    """
    assert (nms_radius >= 0)
    # NMS window size (2 * radius + 1)
    size = nms_radius * 2 + 1
    avg_size = 2  # NOTE: unused

    # Max pooling with stride 1 and matching padding (spatial size preserved)
    def max_pool(x):
        return torch.nn.functional.max_pool2d(x, kernel_size=size, stride=1, padding=nms_radius)

    # Zero tensor with the same shape as the input
    zeros = torch.zeros_like(scores)
    # max_map = max_pool(scores)

    # Step 1: identify local maxima by comparing each point with the maximum
    # of its neighbourhood
    max_mask = scores == max_pool(scores)  # max_pool(scores): every pixel is replaced by the maximum of its local window
    # Step 2: add a tiny random perturbation in [0, 0.1) to break ties between
    # equal maxima
    max_mask_ = torch.rand(max_mask.shape).to(max_mask.device) / 10
    # Random values with the same shape as max_mask act as the perturbation;
    # everything that is not a local maximum is zeroed.
    max_mask_[~max_mask] = 0

    # Step 3: apply NMS again on the perturbed map to find the points that
    # are still local maxima after the perturbation
    mask = ((max_mask_ == max_pool(max_mask_)) & (max_mask_ > 0))  # boolean mask of the surviving local maxima

    # Step 4: keep the scores of the local maxima, zero everything else
    return torch.where(mask, scores, zeros)  # keep the original score where mask is True, else set it to zero
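# Usage sketch (illustrative): a 0.5-scoring neighbour within the NMS radius
# of a 0.9 peak is suppressed, so only the peak survives.
#   >>> scores = torch.zeros(1, 1, 9, 9)
#   >>> scores[0, 0, 4, 4] = 0.9
#   >>> scores[0, 0, 4, 5] = 0.5
#   >>> torch.nonzero(simple_nms(scores, nms_radius=2)[0, 0])
#   tensor([[4, 4]])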


def pre_processing(data):
    """ Enhance retinal images """
    train_imgs = datasets_normalized(data)
    train_imgs = clahe_equalized(train_imgs)
    train_imgs = adjust_gamma(train_imgs, 1.2)

    train_imgs = train_imgs / 255.

    return train_imgs.astype(np.float32)


def rgb2gray(rgb):
    """ 'Gray' conversion: return only the green channel, which carries the
    highest vessel contrast in retinal images """
    r, g, b = rgb.split()
    return g


# CLAHE (Contrast Limited Adaptive Histogram Equalization) enhancement of the
# input image; improves contrast, especially under uneven illumination or when
# fine detail is hard to distinguish.
def clahe_equalized(images):
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    # clipLimit=2.0: caps the contrast amplification; lower values mean less
    # enhancement and avoid blown-out high-contrast regions.
    # tileGridSize=(8, 8): the image is split into 8x8 tiles, each equalized
    # separately, which reduces the artifacts of global equalization.
    images_equalized = np.empty(images.shape)
    images_equalized[:, :] = clahe.apply(np.array(images[:, :],
                                                  dtype=np.uint8))

    return images_equalized


def datasets_normalized(images):
    # After standardization the values still have to be mapped back to 0-255
    # images_normalized = np.empty(images.shape)
    images_std = np.std(images)
    images_mean = np.mean(images)
    images_normalized = (images - images_mean) / (images_std + 1e-6)
    minv = np.min(images_normalized)
    images_normalized = ((images_normalized - minv) /
                         (np.max(images_normalized) - minv)) * 255

    return images_normalized


def adjust_gamma(images, gamma=1.0):
    invGamma = 1.0 / gamma  # inverse of gamma, used to build the lookup table
    table = np.array([((i / 255.0) ** invGamma) * 255
                      for i in np.arange(0, 256)]).astype("uint8")
    # Precomputed gamma-correction lookup table: every input value (0-255)
    # maps to its gamma-transformed output.
    # Construction:
    #   i / 255.0: normalize the pixel value to [0, 1].
    #   (i / 255.0) ** invGamma: apply the gamma-correction formula.
    #   * 255: scale back to [0, 255].
    #   astype("uint8"): convert to uint8 to match the image format.
    new_images = np.empty(images.shape)
    new_images[:, :] = cv2.LUT(np.array(images[:, :],
                                        dtype=np.uint8), table)
    # cv2.LUT is OpenCV's fast table lookup: every pixel is gamma-corrected
    # through `table`, which is much faster than per-pixel computation.

    return new_images
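# Worked example (illustrative): with gamma = 1.2 the table maps a dark pixel
# of value 50 to roughly 65, lifting shadow detail.
#   >>> int(((50 / 255.0) ** (1 / 1.2)) * 255)
#   65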


def nms(detector_pred, nms_thresh=0.1, nms_size=10, detector_label=None, mask=False):
    """
    Apply non-maximum suppression (NMS) to the detector predictions.

    Args:
        detector_pred: detector predictions (B, 1, H, W)
        nms_thresh: NMS score threshold
        nms_size: NMS neighbourhood size
        detector_label: detector labels (optional)
        mask: whether to use the label mask (currently not implemented)

    Returns:
        List of keypoint locations (one (N, 2) array per image).
    """
    # Work on a copy so the original predictions stay untouched
    detector_pred = detector_pred.clone().detach()
    # Batch size and image dimensions
    B, _, h, w = detector_pred.shape

    # if mask:
    #     assert detector_label is not None
    #     detector_pred[detector_pred < nms_thresh] = 0
    #     label_mask = detector_label
    #
    #     # more area
    #
    #     detector_label = detector_label.long().cpu().numpy()
    #     detector_label = detector_label.astype(np.uint8)
    #     kernel = np.ones((3, 3), np.uint8)
    #     label_mask = np.array([cv2.dilate(detector_label[s, 0], kernel, iterations=1)
    #                            for s in range(len(detector_label))])
    #     label_mask = torch.from_numpy(label_mask).unsqueeze(1)
    #     detector_pred[label_mask > 1e-6] = 0

    # Fast NMS
    scores = simple_nms(detector_pred, nms_size)
    # Reshape the score map to (B, H, W)
    scores = scores.reshape(B, h, w)
    # Keep the points whose score exceeds the threshold
    points = [
        torch.nonzero(s > nms_thresh)
        for s in scores]
    # Gather the scores of those points
    scores = [s[tuple(k.t())] for s, k in zip(scores, points)]
    # Drop points close to the border
    points, scores = list(zip(*[
        remove_borders(k, s, 8, h, w)
        for k, s in zip(points, scores)]))
    # Flip the coordinate order: [y, x] -> [x, y]
    points = [torch.flip(k, [1]).long() for k in points]

    return points
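# Usage sketch (random inputs, shapes only): one (N, 2) tensor of [x, y]
# keypoints per image in the batch.
#   >>> pred = torch.rand(2, 1, 128, 128)
#   >>> pts = nms(pred, nms_thresh=0.1, nms_size=10)
#   >>> len(pts)
#   2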


# The model actually outputs descriptors at 1/8 x 1/8 of the input size; this
# samples them (with bilinear interpolation) at full-resolution keypoint
# locations. Used when computing the PKE loss.
def sample_keypoint_desc(keypoints, descriptors, s: int = 8):
    """
    Sample descriptors at the keypoint locations.

    Args:
        keypoints: keypoint coordinates (B, N, 2), in [x, y] order
        descriptors: descriptor map (B, C, H, W)
        s: downsampling factor of the descriptor map w.r.t. the input image

    Returns:
        Sampled descriptors (B, C, N).
    """
    # Shape of the descriptor tensor, needed below
    b, c, h, w = descriptors.shape  # input descriptors: (b, c, h, w)

    # Clone the keypoints and cast to float for the coordinate arithmetic
    keypoints = keypoints.clone().float()

    # Normalize the keypoint coordinates to (0, 1)
    keypoints /= torch.tensor([(w * s - 1), (h * s - 1)]).to(keypoints)[None]

    # Rescale to (-1, 1), the range expected by grid_sample
    keypoints = keypoints * 2 - 1

    # Pass align_corners only on torch versions that support it (>= 1.3).
    # NOTE: the original check `int(torch.__version__[2]) > 2` read a single
    # character and misparsed versions such as '1.13.0'; parse major/minor instead.
    major, minor = (int(v) for v in torch.__version__.split('.')[:2])
    args = {'align_corners': True} if (major, minor) >= (1, 3) else {}

    # Interpolate the descriptors at the keypoint locations
    descriptors = torch.nn.functional.grid_sample(
        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)  # after grid_sample: (b, c, 1, n), n = number of keypoints

    # L2-normalize the descriptors to unit length for downstream use
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1)  # after reshape: (b, c, n), c = 256

    return descriptors
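# Usage sketch (random inputs): a 256-channel descriptor map at 1/8 resolution
# sampled at two full-resolution keypoints gives a (1, 256, 2) tensor.
#   >>> desc_map = torch.randn(1, 256, 32, 32)
#   >>> kpts = torch.tensor([[[10, 20], [100, 50]]])  # (B, N, 2), [x, y]
#   >>> sample_keypoint_desc(kpts, desc_map, s=8).shape
#   torch.Size([1, 256, 2])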


# Used when computing the model's loss; handles the keypoints and their
# affine-mapped counterparts at the same time.
def sample_descriptors(detector_pred, descriptor_pred, affine_descriptor_pred, grid_inverse,
                       nms_size=10, nms_thresh=0.1, scale=8, affine_detector_pred=None):
    """
    Sample descriptors at keypoint locations.

    Args:
        detector_pred: detector predictions on the original image (B, 1, H, W)
        descriptor_pred: descriptor predictions on the original image (B, C, H, W)
        affine_descriptor_pred: descriptor predictions on the affine image (B, C, H, W)
        grid_inverse: inverse-transform grid (B, H, W, 2)
        nms_size: NMS neighbourhood size
        nms_thresh: NMS threshold
        scale: downsampling factor of the descriptor map w.r.t. the input image
        affine_detector_pred: detector predictions on the affine image (optional)

    Returns:
        descriptors: list of descriptors at the original keypoints
        affine_descriptors: list of descriptors at the corresponding affine keypoints
        keypoints: list of keypoint locations in the original image
    """
    # Batch size and image dimensions
    B, _, h, w = detector_pred.shape

    # NMS gives the keypoint locations
    keypoints = nms(detector_pred, nms_size=nms_size, nms_thresh=nms_thresh)

    # Map the keypoints into the affine image via the inverse-transform grid
    affine_keypoints = [
        grid_inverse[s, k[:, 1].long(), k[:, 0].long()]  # grid lookup
        for s, k in enumerate(keypoints)
    ]

    kp = []         # filtered original keypoints
    affine_kp = []  # filtered affine keypoints

    # Per-sample processing
    for s, k in enumerate(affine_keypoints):
        # Drop points that fall outside the affine image
        idx = (k[:, 0] < 1) & (k[:, 0] > -1) & (k[:, 1] < 1) & (k[:, 1] > -1)

        # Keep the surviving original keypoints
        kp.append(keypoints[s][idx])

        # Surviving affine keypoints
        ak = k[idx]

        # Convert normalized coordinates back to pixel coordinates
        ak[:, 0] = (ak[:, 0] + 1) / 2 * (w - 1)  # x coordinate
        ak[:, 1] = (ak[:, 1] + 1) / 2 * (h - 1)  # y coordinate

        affine_kp.append(ak)

    # Sample descriptors at the original keypoint locations
    descriptors = [
        sample_keypoint_desc(k[None], d[None], s=scale)[0]
        for k, d in zip(kp, descriptor_pred)
    ]

    # Sample descriptors at the affine keypoint locations
    affine_descriptors = [
        sample_keypoint_desc(k[None], d[None], s=scale)[0]
        for k, d in zip(affine_kp, affine_descriptor_pred)
    ]

    return descriptors, affine_descriptors, keypoints
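# Usage sketch (assumes an identity warp, so grid_inverse maps every keypoint
# to itself; shapes only):
#   >>> B, H, W = 1, 128, 128
#   >>> det = torch.rand(B, 1, H, W)
#   >>> desc = torch.randn(B, 256, H // 8, W // 8)
#   >>> grid_inv = F.affine_grid(torch.eye(2, 3)[None], (B, 1, H, W), align_corners=True)
#   >>> d, ad, kp = sample_descriptors(det, desc, desc, grid_inv)
#   >>> d[0].shape[0]  # one 256-d descriptor per surviving keypoint
#   256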
utils/image_process.py
ADDED
@@ -0,0 +1,403 @@
import math
import os
import random
import shutil
import time

import torch
from torchvision import transforms as T
from torch.nn import functional as F
import numpy as np
import cv2
import scipy.stats as st
# import config
from tqdm import tqdm
import matplotlib.pyplot as plt


def get_gaussian_kernel(kernlen=21, nsig=5):
    """
    Build a 2-D Gaussian kernel and return it as a non-trainable weight,
    used to create Gaussian heatmaps.

    Args:
        kernlen (int): kernel size (side length), must be odd
        nsig (float): standard deviation of the Gaussian (controls its width)

    Returns:
        torch.Tensor: Gaussian kernel of shape (1, 1, kernlen, kernlen)
    """
    # 1. Sampling interval: kernlen points sampled uniformly over [-nsig, nsig]
    interval = (2 * nsig + 1.) / kernlen
    # 2. 1-D coordinate array: kernlen + 1 points from -nsig - interval/2
    #    to nsig + interval/2
    x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
    # 3. 1-D Gaussian kernel from differences of the standard normal CDF
    kern1d = np.diff(st.norm.cdf(x))
    # 4. 2-D kernel as the outer product of two 1-D kernels
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    # 5. Normalize so all elements sum to 1 (total "energy" of 1)
    kernel = kernel_raw / kernel_raw.sum()
    # 6. Convert to a PyTorch tensor and add batch/channel dims: (1, 1, H, W)
    kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
    # 7. Wrap as a non-trainable parameter
    weight = torch.nn.Parameter(data=kernel, requires_grad=False)
    # 8. Min-max normalize the values to [0, 1]
    weight = (weight - weight.min()) / (weight.max() - weight.min())

    return weight
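# Usage sketch: the kernel is min-max normalized, so its centre is exactly 1.
#   >>> k = get_gaussian_kernel(kernlen=21, nsig=5)
#   >>> k.shape, float(k.max())
#   (torch.Size([1, 1, 21, 21]), 1.0)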


# Gaussian-blur the point labels used for supervision
def label_gaussian_blur(label_point_positions, kernel, stride=1):
    '''
    Blurs the supervision point labels by convolving them with a Gaussian
    kernel, producing a smoother heatmap.
    Args:
        label_point_positions: point label positions, tensor of shape (B, C, H, W)
        kernel: Gaussian kernel, tensor of shape (1, 1, Hk, Wk)
        stride: convolution stride, default 1
    Returns:
        blurred_label: blurred label heatmap, same shape as the input
    '''
    # Gaussian convolution
    blurred_label = F.conv2d(label_point_positions, kernel, stride=stride, padding=(kernel.shape[-1] - 1) // 2)
    # Clip the heatmap values to at most 1
    blurred_label[blurred_label > 1] = 1

    return blurred_label
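# Usage sketch (illustrative): a single labelled point becomes a Gaussian blob
# whose peak is clipped to 1.
#   >>> pts = torch.zeros(1, 1, 65, 65)
#   >>> pts[0, 0, 32, 32] = 1.0
#   >>> heat = label_gaussian_blur(pts, get_gaussian_kernel(21, 5))
#   >>> heat.shape, float(heat[0, 0, 32, 32])
#   (torch.Size([1, 1, 65, 65]), 1.0)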


def affine_images(images, used_for='detector'):
    """
    Perform a random affine transformation on images. Depending on whether the
    loss being trained is the detector's or the descriptor's, a different
    transformation range (and its inverse) is generated; doubles as data
    augmentation.

    :param images: input image tensor of shape (B, C, H, W)
    :param used_for: purpose of the transform - 'detector' or 'descriptor'
    :return: detached copy of the affine-transformed images (output.detach().clone()),
             the affine transformation grid `grid`,
             and the inverse transformation grid `grid_inverse`
    """

    h, w = images.shape[2:]  # image size
    theta = None
    thetaI = None

    # Draw random affine parameters for every image in the batch
    # (get_params is a static method; the constructor's `20` degrees is unused)
    for i in range(len(images)):
        if used_for == 'detector':
            affine_params = T.RandomAffine(20).get_params(degrees=[-15, 15],        # rotation range: +/-15 degrees
                                                          translate=[0.2, 0.2],     # translation range: +/-20%
                                                          scale_ranges=[0.9, 1.35],  # scale range: 0.9-1.35x
                                                          shears=None,              # no shear
                                                          img_size=[h, w])
        else:
            affine_params = T.RandomAffine(20).get_params(degrees=[-3, 3],          # rotation range: +/-3 degrees
                                                          translate=[0.1, 0.1],     # translation range: 10% of width/height
                                                          scale_ranges=[0.9, 1.1],  # scale range: 0.9-1.1x
                                                          shears=None,              # no shear
                                                          img_size=[h, w])
        # Build the transformation matrix and its inverse from the parameters
        angle = -affine_params[0] * math.pi / 180
        theta_ = torch.tensor([
            [1 / affine_params[2] * math.cos(angle), math.sin(-angle), -affine_params[1][0] / images.shape[2]],
            [math.sin(angle), 1 / affine_params[2] * math.cos(angle), -affine_params[1][1] / images.shape[3]],
            [0, 0, 1]
        ], dtype=torch.float).to(images)
        thetaI_ = theta_.inverse()
        theta_ = theta_[:2]
        thetaI_ = thetaI_[:2]

        # Stack the matrices and their inverses into one batch
        theta_ = theta_.unsqueeze(0)
        thetaI_ = thetaI_.unsqueeze(0)
        theta = theta_ if theta is None else torch.cat((theta, theta_))  # None means this is the first one; otherwise append
        thetaI = thetaI_ if thetaI is None else torch.cat((thetaI, thetaI_))

    # Build the sampling grids from the matrices. A transformation grid is a
    # regular coordinate grid that records, for every pixel of the original
    # image, its target position after the affine transform.
    grid = F.affine_grid(theta, images.size(), align_corners=True)
    grid = grid.to(images)  # move to the same device as the images
    grid_inverse = F.affine_grid(thetaI, images.size(), align_corners=True)
    grid_inverse = grid_inverse.to(images)  # move to the same device as the images
    output = F.grid_sample(images, grid, align_corners=True)  # sample to obtain the transformed images

    # Extra random augmentation when the images are used for the descriptor
    if used_for == 'descriptor':
        if random.random() >= 0.4:
            # Color jitter (effectively a grayscale change)
            output = output.repeat(1, 3, 1, 1)  # replicate the single channel to three (color transforms need RGB)
            output = T.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.3, hue=0.2)(output)
            output = T.Grayscale()(output)  # back to grayscale

    return output.detach().clone(), grid, grid_inverse
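# Usage sketch (random inputs, shapes only): the grids come back as
# (B, H, W, 2) flow fields usable with F.grid_sample.
#   >>> imgs = torch.rand(2, 1, 256, 256)
#   >>> warped, grid, grid_inv = affine_images(imgs, used_for='detector')
#   >>> warped.shape, grid.shape
#   (torch.Size([2, 1, 256, 256]), torch.Size([2, 256, 256, 2]))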


def affine_images_with_mask(images, used_for='detector'):
    """
    Same as affine_images, but additionally returns a mask of the region the
    original and warped images have in common. Depending on whether the loss
    being trained is the detector's or the descriptor's, a different
    transformation range (and its inverse) is generated; doubles as data
    augmentation.

    :param images: input image tensor of shape (B, C, H, W)
    :param used_for: purpose of the transform - 'detector' or 'descriptor'
    :return: detached copy of the affine-transformed images (output.detach().clone()),
             the affine transformation grid `grid`,
             the inverse transformation grid `grid_inverse`,
             and the valid-region mask `valid_mask`
    """

    h, w = images.shape[2:]  # image size
    theta = None
    thetaI = None

    # ---------------------------------------------------------------------
    # New: coordinate grid of the original image
    # (computed but not used further in this function)
    base_grid = F.affine_grid(torch.eye(2, 3).unsqueeze(0), (1, 1, h, w), align_corners=True)
    base_grid = base_grid.expand(images.size(0), *base_grid.shape[1:])
    # ---------------------------------------------------------------------

    # Draw random affine parameters for every image in the batch
    for i in range(len(images)):
        if used_for == 'detector':
            affine_params = T.RandomAffine(20).get_params(degrees=[-15, 15],        # rotation range: +/-15 degrees
                                                          translate=[0.2, 0.2],     # translation range: +/-20%
                                                          scale_ranges=[0.9, 1.35],  # scale range: 0.9-1.35x
                                                          shears=None,              # no shear
                                                          img_size=[h, w])
        else:
            affine_params = T.RandomAffine(20).get_params(degrees=[-3, 3],          # rotation range: +/-3 degrees
                                                          translate=[0.1, 0.1],     # translation range: 10% of width/height
                                                          scale_ranges=[0.9, 1.1],  # scale range: 0.9-1.1x
                                                          shears=None,              # no shear
                                                          img_size=[h, w])
        # Build the transformation matrix and its inverse from the parameters
        angle = -affine_params[0] * math.pi / 180
        theta_ = torch.tensor([
            [1 / affine_params[2] * math.cos(angle), math.sin(-angle), -affine_params[1][0] / images.shape[2]],
            [math.sin(angle), 1 / affine_params[2] * math.cos(angle), -affine_params[1][1] / images.shape[3]],
            [0, 0, 1]
        ], dtype=torch.float).to(images)
        thetaI_ = theta_.inverse()
        theta_ = theta_[:2]
        thetaI_ = thetaI_[:2]

        # Stack the matrices and their inverses into one batch
        theta_ = theta_.unsqueeze(0)
        thetaI_ = thetaI_.unsqueeze(0)
        theta = theta_ if theta is None else torch.cat((theta, theta_))  # None means this is the first one; otherwise append
        thetaI = thetaI_ if thetaI is None else torch.cat((thetaI, thetaI_))

    # Build the sampling grids from the matrices. A transformation grid is a
    # regular coordinate grid that records, for every pixel of the original
    # image, its target position after the affine transform.
    grid = F.affine_grid(theta, images.size(), align_corners=True)
    grid = grid.to(images)  # move to the same device as the images
    grid_inverse = F.affine_grid(thetaI, images.size(), align_corners=True)
    grid_inverse = grid_inverse.to(images)  # move to the same device as the images
    output = F.grid_sample(images, grid, align_corners=True)  # sample to obtain the transformed images

    # ---------------------------------------------------------------------
    # New: compute the valid-region mask
    # 1. Start from an all-ones mask (every pixel initially valid)
    valid_mask = torch.ones_like(images[:, :1])  # (B, 1, H, W)
    # 2. Apply the same transform to the mask
    transformed_mask = F.grid_sample(valid_mask, grid, align_corners=True)
    # 3. Map it back into the original image coordinates
    valid_mask = F.grid_sample(transformed_mask, grid_inverse, align_corners=True)
    # 4. Binarize (values above 0.5 count as valid)
    valid_mask = (valid_mask > 0.5).float()
    # ---------------------------------------------------------------------

    # Extra random augmentation when the images are used for the descriptor
    if used_for == 'descriptor':
        if random.random() >= 0.4:
            # Color jitter (effectively a grayscale change)
            output = output.repeat(1, 3, 1, 1)  # replicate the single channel to three (color transforms need RGB)
            output = T.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.3, hue=0.2)(output)
            output = T.Grayscale()(output)  # back to grayscale

    return output.detach().clone(), grid, grid_inverse, valid_mask.detach().clone()
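# Usage sketch (shapes only): the extra return value is a binary (B, 1, H, W)
# mask marking pixels that stay inside the frame after the forward/inverse
# round trip.
#   >>> imgs = torch.rand(1, 1, 256, 256)
#   >>> warped, grid, grid_inv, valid = affine_images_with_mask(imgs)
#   >>> valid.shape
#   torch.Size([1, 1, 256, 256])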


def pre_processing(data: np.ndarray) -> np.ndarray:
    """
    Retinal image pre-processing pipeline: normalization, CLAHE contrast
    enhancement, and gamma correction.

    Args:
        data: input retinal image (single channel, numpy array)

    Returns:
        Pre-processed image (float32, range [0, 1]).
    """
    train_imgs = datasets_normalized(data)
    train_imgs = clahe_equalized(train_imgs)
    train_imgs = adjust_gamma(train_imgs, 1.2)
    train_imgs = train_imgs / 255.  # final normalization to [0, 1]
    return train_imgs.astype(np.float32)
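# Usage sketch (random single-channel input): the pipeline returns float32
# values in [0, 1].
#   >>> gray = np.random.rand(256, 256) * 255
#   >>> out = pre_processing(gray)
#   >>> out.dtype, 0.0 <= float(out.min()) <= float(out.max()) <= 1.0
#   (dtype('float32'), True)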


def datasets_normalized(images: np.ndarray) -> np.ndarray:
    # After standardization the values still have to be mapped back to 0-255
    # images_normalized = np.empty(images.shape)
    # Global statistics
    images_std = np.std(images)
    images_mean = np.mean(images)
    # Standardize: (x - mean) / std, with a small epsilon to avoid division
    # by zero
    images_normalized = (images - images_mean) / (images_std + 1e-6)
    # Linearly map to [0, 255]
    minv = np.min(images_normalized)
    images_normalized = ((images_normalized - minv) / (np.max(images_normalized) - minv)) * 255

    return images_normalized
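# Worked example (illustrative): whatever the input range, the output is
# stretched to exactly [0, 255].
#   >>> img = np.array([[0., 128.], [255., 64.]])
#   >>> out = datasets_normalized(img)
#   >>> float(out.min()), float(out.max())
#   (0.0, 255.0)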


def clahe_equalized(images):
    # CLAHE (Contrast Limited Adaptive Histogram Equalization) enhancement of
    # the input image; improves contrast, especially under uneven illumination
    # or when fine detail is hard to distinguish.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    # clipLimit=2.0: caps the contrast amplification; lower values mean less
    # enhancement and avoid blown-out high-contrast regions.
    # tileGridSize=(8, 8): the image is split into 8x8 tiles, each equalized
    # separately, which reduces the artifacts of global equalization.
    images_equalized = np.empty(images.shape)
    images_equalized[:, :] = clahe.apply(np.array(images[:, :], dtype=np.uint8))

    return images_equalized


def adjust_gamma(images, gamma=1.0):
    # gamma > 1 compresses highlights and lifts dark detail (well suited to
    # retinal images)
    # Build the gamma-correction lookup table
    invGamma = 1.0 / gamma  # inverse of gamma, used to build the lookup table
    table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8")
    # Precomputed gamma-correction values for fast lookup: every input value
    # (0-255) maps to its gamma-transformed output.
    # Construction:
    #   i / 255.0: normalize the pixel value to [0, 1].
    #   (i / 255.0) ** invGamma: apply the gamma-correction formula.
    #   * 255: scale back to [0, 255].
    #   astype("uint8"): convert to uint8 to match the image format.
    new_images = np.empty(images.shape)
    # Apply the correction
    new_images[:, :] = cv2.LUT(np.array(images[:, :], dtype=np.uint8), table)
    # cv2.LUT is OpenCV's fast table lookup: every pixel is gamma-corrected
    # through `table`, much faster than per-pixel computation.

    return new_images


def simple_nms(scores, nms_radius: int):
    """
    Fast non-maximum suppression (NMS) used to remove adjacent keypoints.
    Args:
        scores: keypoint score map, (B, H, W) or (H, W)
        nms_radius: NMS neighbourhood radius
    Returns:
        Score map after NMS.
    """
    assert (nms_radius >= 0)
    # NMS window size (2 * radius + 1)
    size = nms_radius * 2 + 1
    avg_size = 2  # NOTE: unused

    # Max pooling with stride 1 and matching padding (spatial size preserved)
    def max_pool(x):
        return torch.nn.functional.max_pool2d(x, kernel_size=size, stride=1, padding=nms_radius)

    # Zero tensor with the same shape as the input
    zeros = torch.zeros_like(scores)
    # max_map = max_pool(scores)

    # Step 1: identify local maxima by comparing each point with the maximum
    # of its neighbourhood
    max_mask = scores == max_pool(scores)  # max_pool(scores): every pixel is replaced by the maximum of its local window
    # Step 2: add a tiny random perturbation in [0, 0.1) to break ties between
    # equal maxima
    max_mask_ = torch.rand(max_mask.shape).to(max_mask.device) / 10
    # Random values with the same shape as max_mask act as the perturbation;
    # everything that is not a local maximum is zeroed.
    max_mask_[~max_mask] = 0
    # Step 3: apply NMS again on the perturbed map to find the points that
    # are still local maxima after the perturbation
    mask = ((max_mask_ == max_pool(max_mask_)) & (max_mask_ > 0))  # boolean mask of the surviving local maxima
    # Step 4: keep the scores of the local maxima, zero everything else
    return torch.where(mask, scores, zeros)  # keep the original score where mask is True, else set it to zero


def remove_borders(keypoints, scores, border: int, height: int, width: int):
    '''
    Remove keypoints (and their scores) that lie too close to the image border.
    Used only at inference time; training does not call this function.
    keypoints: (N, 2) array of keypoint coordinates, one (y, x) pair per row.
    scores: score of each keypoint, shape (N,).
    border: width of the border strip to remove; the inference default is 4 px.
    height: image height.
    width: image width.
    '''
    mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))  # height mask: keypoints must lie in [border, height - border)
    mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))   # width mask: keypoints must lie in [border, width - border)
    mask = mask_h & mask_w  # combined mask (both conditions must hold)
    return keypoints[mask], scores[mask]  # filtered keypoints and scores


def nms(detector_pred, nms_thresh=0.1, nms_size=10, detector_label=None, mask=False):
    """
    Apply non-maximum suppression (NMS) to the detector predictions.

    Args:
        detector_pred: detector predictions (B, 1, H, W)
        nms_thresh: NMS score threshold
        nms_size: NMS neighbourhood size
        detector_label: detector labels (optional)
        mask: whether to use the label mask (currently not implemented)

    Returns:
        List of keypoint locations (one (N, 2) array per image).
    """
    detector_pred = detector_pred.clone().detach()  # work on a copy so the original predictions stay untouched
    B, _, h, w = detector_pred.shape  # batch size and image dimensions
    # if mask:
    #     assert detector_label is not None
    #     detector_pred[detector_pred < nms_thresh] = 0
    #     label_mask = detector_label
    #
    #     # more area
    #
    #     detector_label = detector_label.long().cpu().numpy()
    #     detector_label = detector_label.astype(np.uint8)
    #     kernel = np.ones((3, 3), np.uint8)
    #     label_mask = np.array([cv2.dilate(detector_label[s, 0], kernel, iterations=1)
    #                            for s in range(len(detector_label))])
    #     label_mask = torch.from_numpy(label_mask).unsqueeze(1)
    #     detector_pred[label_mask > 1e-6] = 0
    scores = simple_nms(detector_pred, nms_size)  # fast NMS
    scores = scores.reshape(B, h, w)  # reshape the score map to (B, H, W)
    # print(f"scores before thresh {nms_thresh}", scores)
    points = [torch.nonzero(s > nms_thresh) for s in scores]  # keep points whose score exceeds the threshold
    # print("points after thresh", points)
    scores = [s[tuple(k.t())] for s, k in zip(scores, points)]  # gather the scores of those points
    points, scores = list(zip(*[remove_borders(k, s, 8, h, w) for k, s in zip(points, scores)]))  # drop points near the border
    points = [torch.flip(k, [1]).long() for k in points]  # flip coordinate order: [y, x] -> [x, y]

    return points