Hongyang Li commited on
Commit
bddd311
·
verified ·
1 Parent(s): 8e92669

Upload 2 files

Browse files
Files changed (2) hide show
  1. utils/common_util.py +301 -0
  2. utils/image_process.py +403 -0
utils/common_util.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import functional as F
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+
5
+ from torchvision import transforms
6
+ import numpy as np
7
+ from PIL import Image
8
+ import cv2
9
+ import torch.nn as nn
10
+ import scipy.stats as st
11
+
12
+
13
+ # 从提供的关键点和对应分数中移除那些过于靠近图像边缘的关键点
14
+ # 这个函数是在推理的时候用的,训练的时候没有使用
15
+ def remove_borders(keypoints, scores, border: int, height: int, width: int):
16
+ """ Removes keypoints too close to the border """
17
+ '''
18
+ keypoints: 关键点坐标的二维数组,形状为 (N, 2),其中每行表示一个关键点的 (y, x) 坐标。
19
+ scores: 每个关键点对应的分数,形状为 (N,)。
20
+ border: 表示需要移除的边界宽度。 推理时预设的是4像素
21
+ height: 图像的高度。
22
+ width: 图像的宽度。
23
+ '''
24
+ # 创建高度方向掩码: 关键点必须在 [border, height-border) 范围内
25
+ mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
26
+ # 创建宽度方向掩码: 关键点必须在 [border, width-border) 范围内
27
+ mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
28
+ # 组合掩码 (必须同时满足高度和宽度条件)
29
+ mask = mask_h & mask_w # 所以这个mask是判定条件
30
+ # 返回过滤后的关键点和分数
31
+ return keypoints[mask], scores[mask]
32
+
33
+
34
+ def simple_nms(scores, nms_radius: int):
35
+ """
36
+ 快速非极大值抑制 (NMS) 算法,用于移除相邻关键点
37
+
38
+ 参数:
39
+ scores: 关键点分数图 (B, H, W) 或 (H, W)
40
+ nms_radius: NMS邻域半径
41
+
42
+ 返回:
43
+ NMS处理后的分数图
44
+ """
45
+ assert (nms_radius >= 0)
46
+ # 计算NMS窗口大小 (2*半径+1)
47
+ size = nms_radius * 2 + 1
48
+ avg_size = 2
49
+ # 定义最大池化函数 (使用固定步长1和适当填充)
50
+ def max_pool(x):
51
+ return torch.nn.functional.max_pool2d(x, kernel_size=size, stride=1, padding=nms_radius)
52
+
53
+ # 创建与输入相同形状的零张量
54
+ zeros = torch.zeros_like(scores)
55
+ # max_map = max_pool(scores)
56
+
57
+ # 步骤1: 识别局部最大值点
58
+ # 比较每个点与其邻域内的最大值
59
+ max_mask = scores == max_pool(scores) # max_pool(scores):每个像素点被替换为其局部窗口内的最大值。
60
+ # 步骤2: 添加微小随机扰动 (避免多个相同最大值)
61
+ # 生成 [0, 0.1) 范围内的随机数
62
+ max_mask_ = torch.rand(max_mask.shape).to(max_mask.device) / 10
63
+ # 生成与 max_mask 相同形状的随机数(范围在 [0, 0.1)),作为微小扰动。
64
+ # 非局部最大值点置零
65
+ max_mask_[~max_mask] = 0
66
+
67
+ # 步骤3: 对扰动后的图再次应用NMS
68
+ # 识别扰动后仍然是局部最大值的点
69
+ mask = ((max_mask_ == max_pool(max_mask_)) & (max_mask_ > 0)) # mask:布尔掩码,仅保留扰动后仍然是局部最大值的点。
70
+
71
+ # 步骤4: 保留局部最大值点,其他点置零
72
+ return torch.where(mask, scores, zeros) # 如果 mask 为 True,保留原始分数。否则,将得分设置为零。
73
+
74
+
75
+ def pre_processing(data):
76
+ """ Enhance retinal images """
77
+ train_imgs = datasets_normalized(data)
78
+ train_imgs = clahe_equalized(train_imgs)
79
+ train_imgs = adjust_gamma(train_imgs, 1.2)
80
+
81
+ train_imgs = train_imgs / 255.
82
+
83
+ return train_imgs.astype(np.float32)
84
+
85
+
86
+ def rgb2gray(rgb):
87
+ """ Convert RGB image to gray image """
88
+ r, g, b = rgb.split()
89
+ return g
90
+
91
+
92
+ # 对输入图像的 CLAHE(Contrast Limited Adaptive Histogram Equalization)增强处理
93
+ # 用于提高图像的对比度,特别是在光照不均或细节难以分辨的情况下。
94
+ def clahe_equalized(images):
95
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
96
+ # clipLimit=2.0: 限制对比度的参数,值越低,增强的对比度越小,避免高对比度导致的过曝光区域。
97
+ # tileGridSize=(8, 8): 将图像分成大小为 8x8 的网格,每个网格单独进行直方图均衡化,减少全局对比度增强引入的伪影。
98
+ images_equalized = np.empty(images.shape)
99
+ images_equalized[:, :] = clahe.apply(np.array(images[:, :],
100
+ dtype=np.uint8))
101
+
102
+ return images_equalized
103
+
104
+
105
+ def datasets_normalized(images):
106
+ # 归一化之后还需要把值映射到0到255
107
+ # images_normalized = np.empty(images.shape)
108
+ images_std = np.std(images)
109
+ images_mean = np.mean(images)
110
+ images_normalized = (images - images_mean) / (images_std + 1e-6)
111
+ minv = np.min(images_normalized)
112
+ images_normalized = ((images_normalized - minv) /
113
+ (np.max(images_normalized) - minv)) * 255
114
+
115
+ return images_normalized
116
+
117
+
118
+ def adjust_gamma(images, gamma=1.0):
119
+ invGamma = 1.0 / gamma # invGamma: 伽马值的倒数,用于生成查找表。
120
+ table = np.array([((i / 255.0) ** invGamma) * 255
121
+ for i in np.arange(0, 256)]).astype("uint8")
122
+ # 预计算伽��校正的转换值,用于快速查找。
123
+ # 每个输入像素值(0-255)都映射到一个经过伽马变换的输出值。
124
+ # 生成过程:
125
+ # i / 255.0: 将像素值归一化到 [0, 1] 范围。
126
+ # (i / 255.0) ** invGamma: 应用伽马校正公式。
127
+ # * 255: 将归一化后的值还原到 [0, 255] 范围。
128
+ # astype("uint8"): 转换为 uint8 数据类型,适配图像格式。
129
+ new_images = np.empty(images.shape)
130
+ new_images[:, :] = cv2.LUT(np.array(images[:, :],
131
+ dtype=np.uint8), table)
132
+ # cv2.LUT: OpenCV 的快速像素值映射函数。
133
+ # 输入图像的每个像素值通过查找表 table 进行伽马校正。
134
+ # 大幅提高效率,避免逐像素计算。
135
+
136
+ return new_images
137
+
138
+
139
+ def nms(detector_pred, nms_thresh=0.1, nms_size=10, detector_label=None, mask=False):
140
+ """
141
+ 在检测器预测上应用非极大值抑制 (NMS)
142
+
143
+ 参数:
144
+ detector_pred: 检测器预测 (B, 1, H, W)
145
+ nms_thresh: NMS阈值
146
+ nms_size: NMS邻域大小
147
+ detector_label: 检测器标签 (可选)
148
+ mask: 是否使用标签掩码 (当前未实现)
149
+
150
+ 返回:
151
+ 关键点位置列表 (每个元素是 (N, 2) 的数组)
152
+ """
153
+ # 创建预测副本 (避免修改原始数据)
154
+ detector_pred = detector_pred.clone().detach()
155
+ # 获取批次大小和图像尺寸
156
+ B, _, h, w = detector_pred.shape
157
+
158
+ # if mask:
159
+ # assert detector_label is not None
160
+ # detector_pred[detector_pred < nms_thresh] = 0
161
+ # label_mask = detector_label
162
+ #
163
+ # # more area
164
+ #
165
+ # detector_label = detector_label.long().cpu().numpy()
166
+ # detector_label = detector_label.astype(np.uint8)
167
+ # kernel = np.ones((3, 3), np.uint8)
168
+ # label_mask = np.array([cv2.dilate(detector_label[s, 0], kernel, iterations=1)
169
+ # for s in range(len(detector_label))])
170
+ # label_mask = torch.from_numpy(label_mask).unsqueeze(1)
171
+ # detector_pred[label_mask > 1e-6] = 0
172
+
173
+ # 应用快速NMS算法
174
+ scores = simple_nms(detector_pred, nms_size)
175
+ # 重塑分数图形状 (B, H, W)
176
+ scores = scores.reshape(B, h, w)
177
+ # 找出分数高于阈值的点
178
+ points = [
179
+ torch.nonzero(s > nms_thresh)
180
+ for s in scores]
181
+ # 提取这些点的分数值
182
+ scores = [s[tuple(k.t())] for s, k in zip(scores, points)]
183
+ # 移除靠近边界的点
184
+ points, scores = list(zip(*[
185
+ remove_borders(k, s, 8, h, w)
186
+ for k, s in zip(points, scores)]))
187
+ # 翻转坐标顺序: [y, x] -> [x, y]
188
+ points = [torch.flip(k, [1]).long() for k in points]
189
+
190
+ return points
191
+
192
+
193
+ # 实际上模型生成的描述符是1/8尺寸*1/8尺寸的描述符特征,这里是上采样到原尺寸的
194
+ # 这个是在PKE算损失的时候用的
195
+ def sample_keypoint_desc(keypoints, descriptors, s: int = 8):
196
+ """
197
+ 在关键点位置采样描述符
198
+
199
+ 参数:
200
+ keypoints: 关键点坐标 (B, N, 2) 格式 [x, y]
201
+ descriptors: 描述符图 (B, C, H, W)
202
+ s: 描述符图相对于原始图像的下采样比例
203
+
204
+ 返回:
205
+ 采样后的描述符 (B, C, N)
206
+ """
207
+ # 获取描述符张量的形状信息,用于后续处理
208
+ b, c, h, w = descriptors.shape # 原始输入 descriptors: (b, c, h, w)
209
+
210
+ # 克隆关键点并将其转换为浮点类型,以便进行坐标计算
211
+ keypoints = keypoints.clone().float()
212
+
213
+ # 将关键点坐标归一化到范围 (0, 1)
214
+ keypoints /= torch.tensor([(w * s - 1), (h * s - 1)]).to(keypoints)[None]
215
+
216
+ # 将关键点坐标缩放到范围 (-1, 1),以适应 grid_sample 函数的要求
217
+ keypoints = keypoints * 2 - 1
218
+
219
+ # 根据 PyTorch 版本准备 grid_sample 函数的参数,确保兼容性
220
+ args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
221
+
222
+ # 使用 grid_sample 函数在关键点位置插值描述符
223
+ descriptors = torch.nn.functional.grid_sample(
224
+ descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) # 经过 grid_sample: (b, c, 1, n) n是关键点数量
225
+
226
+ # 对描述符进行 L2 归一化,使其长度为 1(1个像素),以便后续处理
227
+ descriptors = torch.nn.functional.normalize(
228
+ descriptors.reshape(b, c, -1), p=2, dim=1) # reshape 后: (b, c, n) channel=256
229
+
230
+ # 返回处理后的描述符
231
+ return descriptors
232
+
233
+
234
+ # 这个是在模型算损失的时候用的,可以同时处理关键点和被映射后的关键点损失
235
+ def sample_descriptors(detector_pred, descriptor_pred, affine_descriptor_pred, grid_inverse,
236
+ nms_size=10, nms_thresh=0.1, scale=8, affine_detector_pred=None):
237
+ """
238
+ 基于关键点采样描述符
239
+
240
+ 参数:
241
+ detector_pred: 原始图像的检测器预测 (B, 1, H, W)
242
+ descriptor_pred: 原始图像的描述符预测 (B, C, H, W)
243
+ affine_descriptor_pred: 仿射图像的���述符预测 (B, C, H, W)
244
+ grid_inverse: 逆变换网格 (B, H, W, 2)
245
+ nms_size: NMS邻域大小
246
+ nms_thresh: NMS阈值
247
+ scale: 描述符图相对于原始图像的下采样比例
248
+ affine_detector_pred: 仿射图像的检测器预测 (可选)
249
+
250
+ 返回:
251
+ descriptors: 原始图像关键点的描述符列表
252
+ affine_descriptors: 仿射图像对应关键点的描述符列表
253
+ keypoints: 原始图像的关键点位置列表
254
+ """
255
+ # 获取批次大小和图像尺寸
256
+ B, _, h, w = detector_pred.shape
257
+
258
+ # 应用NMS获取关键点位置
259
+ keypoints = nms(detector_pred, nms_size=nms_size, nms_thresh=nms_thresh)
260
+
261
+ # 使用逆变换网格将关键点映射到仿射空间
262
+ affine_keypoints = [
263
+ grid_inverse[s, k[:, 1].long(), k[:, 0].long()] # 使用网格插值
264
+ for s, k in enumerate(keypoints)
265
+ ]
266
+
267
+ # 初始化存储列表
268
+ kp = [] # 过滤后的原始关键点
269
+ affine_kp = [] # 过滤后的仿射关键点
270
+
271
+ # 处理每个样本
272
+ for s, k in enumerate(affine_keypoints):
273
+ # 过滤超出仿射图像边界的点
274
+ idx = (k[:, 0] < 1) & (k[:, 0] > -1) & (k[:, 1] < 1) & (k[:, 1] > -1)
275
+
276
+ # 存储过滤后的原始关键点
277
+ kp.append(keypoints[s][idx])
278
+
279
+ # 获取过滤后的仿射关键点
280
+ ak = k[idx]
281
+
282
+ # 将归一化坐标转换回像素坐标
283
+ ak[:, 0] = (ak[:, 0] + 1) / 2 * (w - 1) # x坐标
284
+ ak[:, 1] = (ak[:, 1] + 1) / 2 * (h - 1) # y坐标
285
+
286
+ # 存储转换后的仿射关键点
287
+ affine_kp.append(ak)
288
+
289
+ # 在原始图像关键点位置采样描述符
290
+ descriptors = [
291
+ sample_keypoint_desc(k[None], d[None], s=scale)[0]
292
+ for k, d in zip(kp, descriptor_pred)
293
+ ]
294
+
295
+ # 在仿射图像关键点位置采样描述符
296
+ affine_descriptors = [
297
+ sample_keypoint_desc(k[None], d[None], s=scale)[0]
298
+ for k, d in zip(affine_kp, affine_descriptor_pred)
299
+ ]
300
+
301
+ return descriptors, affine_descriptors, keypoints
utils/image_process.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import shutil
5
+ import time
6
+
7
+ import torch
8
+ from torchvision import transforms as T
9
+ from torch.nn import functional as F
10
+ import numpy as np
11
+ import cv2
12
+ import scipy.stats as st
13
+ # import config
14
+ from tqdm import tqdm
15
+ import matplotlib.pyplot as plt
16
+
17
+
18
+ def get_gaussian_kernel(kernlen=21, nsig=5):
19
+ """
20
+ 生成二维高斯核并返回一个不可训练的核权重,用于创建高斯热图
21
+
22
+ 参数:
23
+ kernlen (int): 核的大小(边长),必须是奇数
24
+ nsig (float): 高斯分布的标准差(控制分布的宽度)
25
+
26
+ 返回:
27
+ torch.Tensor: 形状为(1, 1, kernlen, kernlen)的高斯核张量
28
+ """
29
+ # 1. 计算采样间隔
30
+ # 在[-nsig, nsig]范围内均匀采样kernlen个点
31
+ interval = (2 * nsig + 1.) / kernlen
32
+ # 2. 创建一维坐标点数组
33
+ # 从-nsig-interval/2到nsig+interval/2,共kernlen+1个点
34
+ x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
35
+ # 3. 计算一维高斯核
36
+ # 使用标准正态分布的累积分布函数(CDF)的差分
37
+ kern1d = np.diff(st.norm.cdf(x))
38
+ # 4. 创建二维高斯核
39
+ # 通过两个一维高斯核的外积(outer product)
40
+ kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
41
+ # 5. 归一化核
42
+ # 使所有元素之和为1,确保核的总"能量"为1
43
+ kernel = kernel_raw / kernel_raw.sum()
44
+ # 6. 转换为PyTorch张量并调整维度
45
+ # 添加批次和通道维度 (1, 1, H, W)
46
+ kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
47
+ # 7. 创建不可训练的高斯核权重参数
48
+ weight = torch.nn.Parameter(data=kernel, requires_grad=False)
49
+ # 8. 将值归一化到[0, 1]范围
50
+ # 使最小值为0,最大值为1
51
+ weight = (weight - weight.min()) / (weight.max() - weight.min())
52
+
53
+ return weight
54
+
55
+
56
+ # 对用于监督的点标签进行高斯模糊
57
+ def label_gaussian_blur(label_point_positions, kernel, stride=1):
58
+ '''
59
+ 对用于监督的点标签进行高斯模糊
60
+ 该函数通过应用高斯卷积核来模糊标签点,以生成更平滑的热图
61
+ 参数:
62
+ label_point_positions: 标签点的位置,形状为(B, C, H, W)的张量
63
+ kernel: 高斯卷积核,形状为(1, 1, Hk, Wk)的张量
64
+ stride: 卷积步长,默认为1
65
+ 返回值:
66
+ blurred_label: 模糊后的标签热图,形状与输入相同
67
+ '''
68
+ # 应用高斯卷积进行模糊
69
+ blurred_label = F.conv2d(label_point_positions, kernel, stride=stride, padding=(kernel.shape[-1] - 1) // 2)
70
+ # 裁剪值范围:限制热图最大值为1
71
+ blurred_label[blurred_label > 1] = 1
72
+
73
+ return blurred_label
74
+
75
+
76
+ def affine_images(images, used_for='detector'):
77
+ '''
78
+ 根据detector和descriptor两种计算loss和训练的情况分别生成两种不同的仿射变及其逆变换,支持数据增强
79
+ 参数:
80
+ images: 输入图像张量,形状为 (B, C, H, W)
81
+ used_for: 指定变换用途 - 'detector'用于检测器,'descriptor'用于描述器
82
+ 返回值:
83
+ CPU上的仿射变换后的图像 output.detach().clone()
84
+ 仿射变换网格 grid
85
+ 逆仿射变换网格 grid_inverse
86
+ '''
87
+
88
+ """
89
+ Perform affine transformation on images
90
+ :param images: (B, C, H, W)
91
+ :param keypoint_labels: corresponding labels
92
+ :param value_map: value maps, used to record history learned geo_points
93
+ :return: results of affine images, affine labels, affine value maps, affine transformed grid_inverse, inverse transformed grid_inverse
94
+ """
95
+
96
+ h, w = images.shape[2:] # 获取图像的尺寸
97
+ theta = None
98
+ thetaI = None
99
+
100
+ # 对一个batch中的每个图像进行随机仿射变换参数的生成
101
+ for i in range(len(images)):
102
+ if used_for == 'detector':
103
+ affine_params = T.RandomAffine(20).get_params(degrees=[-15, 15], # 旋转角度范围:±15度
104
+ translate=[0.2, 0.2], # 平移范围:±20%
105
+ scale_ranges=[0.9, 1.35], # 缩放范围:0.9-1.35倍
106
+ shears=None, # 不使用剪切变换
107
+ img_size=[h, w])
108
+ else:
109
+ affine_params = T.RandomAffine(20).get_params(degrees=[-3, 3], # 旋转角度范围:±3度
110
+ translate=[0.1, 0.1], # 平移范围:宽高的10%
111
+ scale_ranges=[0.9, 1.1], # 缩放范围:0.9到1.1倍
112
+ shears=None, # 不使用剪切变换
113
+ img_size=[h, w])
114
+ # 根据仿射变换参数计算变换矩阵和逆变换矩阵
115
+ angle = -affine_params[0] * math.pi / 180
116
+ theta_ = torch.tensor([
117
+ [1 / affine_params[2] * math.cos(angle), math.sin(-angle), -affine_params[1][0] / images.shape[2]],
118
+ [math.sin(angle), 1 / affine_params[2] * math.cos(angle), -affine_params[1][1] / images.shape[3]],
119
+ [0, 0, 1]
120
+ ], dtype=torch.float).to(images)
121
+ thetaI_ = theta_.inverse()
122
+ theta_ = theta_[:2]
123
+ thetaI_ = thetaI_[:2]
124
+
125
+ # 将变换矩阵和逆变换矩阵叠成一个batch
126
+ theta_ = theta_.unsqueeze(0)
127
+ thetaI_ = thetaI_.unsqueeze(0)
128
+ theta = theta_ if theta is None else torch.cat((theta, theta_)) # 如果theta为None,意味着这是第一个,否则则说明前面已经有参数值了
129
+ thetaI = thetaI_ if thetaI is None else torch.cat((thetaI, thetaI_))
130
+
131
+ # 根据变换矩阵生成网格
132
+ # 变换网格(Transformation Grid)是在进行图像的仿射变换时生成的一个规则的坐标网格。它定义了原始图像中每个像素在变换后的目标位置。
133
+ grid = F.affine_grid(theta, images.size(), align_corners=True)
134
+ grid = grid.to(images) # 将参数移动到GPU上
135
+ grid_inverse = F.affine_grid(thetaI, images.size(), align_corners=True)
136
+ grid_inverse = grid_inverse.to(images) # 将参数移动到GPU上
137
+ output = F.grid_sample(images, grid, align_corners=True) # 对图像进行采样得到仿射变换后的图像
138
+
139
+ # 对于用于描述符的情况进行一些额外的随机增强操作
140
+ if used_for == 'descriptor':
141
+ if random.random() >= 0.4:
142
+ # 颜色抖动(其实是灰度变化)
143
+ output = output.repeat(1, 3, 1, 1) # 将单通道图像复制为三通道(颜色变换需要RGB)
144
+ output = T.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.3, hue=0.2)(output)
145
+ output = T.Grayscale()(output) # 灰度化
146
+
147
+ return output.detach().clone(), grid, grid_inverse
148
+
149
+
150
+ def affine_images_with_mask(images, used_for='detector'):
151
+ '''
152
+ 仿射变换图像后,额外返回一个公共区域mask
153
+ 根据detector和descriptor两种计算loss和训练的情况分别生成两种不同的仿射变及其逆变换,支持数据增强
154
+ 参数:
155
+ images: 输入图像张量,形状为 (B, C, H, W)
156
+ used_for: 指定变换用途 - 'detector'用于检测器,'descriptor'用于描述器
157
+ 返回值:
158
+ CPU上的仿射变换后的图像 output.detach().clone()
159
+ 仿射变换网格 grid
160
+ 逆仿射变换网格 grid_inverse
161
+ '''
162
+
163
+ """
164
+ Perform affine transformation on images
165
+ :param images: (B, C, H, W)
166
+ :param keypoint_labels: corresponding labels
167
+ :param value_map: value maps, used to record history learned geo_points
168
+ :return: results of affine images, affine labels, affine value maps, affine transformed grid_inverse, inverse transformed grid_inverse
169
+ """
170
+
171
+ h, w = images.shape[2:] # 获取图像的尺寸
172
+ theta = None
173
+ thetaI = None
174
+
175
+ # ---------------------------------------------------------------------
176
+ # 新增: 创建原始图像坐标网格
177
+ base_grid = F.affine_grid(torch.eye(2, 3).unsqueeze(0), (1, 1, h, w), align_corners=True)
178
+ base_grid = base_grid.expand(images.size(0), *base_grid.shape[1:])
179
+ # ---------------------------------------------------------------------
180
+
181
+ # 对一个batch中的每个图像进行随机仿射变换参数的生成
182
+ for i in range(len(images)):
183
+ if used_for == 'detector':
184
+ affine_params = T.RandomAffine(20).get_params(degrees=[-15, 15], # 旋转角度范围:±15度
185
+ translate=[0.2, 0.2], # 平移范围:±20%
186
+ scale_ranges=[0.9, 1.35], # 缩放范围:0.9-1.35倍
187
+ shears=None, # 不使用剪切变换
188
+ img_size=[h, w])
189
+ else:
190
+ affine_params = T.RandomAffine(20).get_params(degrees=[-3, 3], # 旋转角度范围:±3度
191
+ translate=[0.1, 0.1], # 平移范围:宽高的10%
192
+ scale_ranges=[0.9, 1.1], # 缩放范围:0.9到1.1倍
193
+ shears=None, # 不使用剪切变换
194
+ img_size=[h, w])
195
+ # 根据仿射变换参数计算变换矩阵和逆变换矩阵
196
+ angle = -affine_params[0] * math.pi / 180
197
+ theta_ = torch.tensor([
198
+ [1 / affine_params[2] * math.cos(angle), math.sin(-angle), -affine_params[1][0] / images.shape[2]],
199
+ [math.sin(angle), 1 / affine_params[2] * math.cos(angle), -affine_params[1][1] / images.shape[3]],
200
+ [0, 0, 1]
201
+ ], dtype=torch.float).to(images)
202
+ thetaI_ = theta_.inverse()
203
+ theta_ = theta_[:2]
204
+ thetaI_ = thetaI_[:2]
205
+
206
+ # 将变换矩阵和逆变换矩阵叠成一个batch
207
+ theta_ = theta_.unsqueeze(0)
208
+ thetaI_ = thetaI_.unsqueeze(0)
209
+ theta = theta_ if theta is None else torch.cat((theta, theta_)) # 如果theta为None,意味着这是第一个,否则则说明前面已经有参数值了
210
+ thetaI = thetaI_ if thetaI is None else torch.cat((thetaI, thetaI_))
211
+
212
+ # 根据变换矩阵生成网格
213
+ # 变换网格(Transformation Grid)是在进行图像的仿射变换时生成的一个规则的坐标网格。它定义了原始图像中每个像素在变换后的目标位置。
214
+ grid = F.affine_grid(theta, images.size(), align_corners=True)
215
+ grid = grid.to(images) # 将参数移动到GPU上
216
+ grid_inverse = F.affine_grid(thetaI, images.size(), align_corners=True)
217
+ grid_inverse = grid_inverse.to(images) # 将参数移动到GPU上
218
+ output = F.grid_sample(images, grid, align_corners=True) # 对图像进行采样得到仿射变换后的图像
219
+
220
+ # ---------------------------------------------------------------------
221
+ # 新增: 计算有效区域mask
222
+ # 1. 创建全1mask (表示所有像素初始有效)
223
+ valid_mask = torch.ones_like(images[:, :1]) # (B, 1, H, W)
224
+ # 2. 应用相同的变换到mask上
225
+ transformed_mask = F.grid_sample(valid_mask, grid, align_corners=True)
226
+ # 3. 反向映射回原始图像坐标
227
+ valid_mask = F.grid_sample(transformed_mask, grid_inverse, align_corners=True)
228
+ # 4. 二值化 (大于0.5视为有效)
229
+ valid_mask = (valid_mask > 0.5).float()
230
+ # ---------------------------------------------------------------------
231
+
232
+ # 对于用于描述符的情况进行一些额外的随机增强操作
233
+ if used_for == 'descriptor':
234
+ if random.random() >= 0.4:
235
+ # 颜色抖动(其实是灰度变化)
236
+ output = output.repeat(1, 3, 1, 1) # 将单通道图像复制为三通道(颜色变换需要RGB)
237
+ output = T.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.3, hue=0.2)(output)
238
+ output = T.Grayscale()(output) # 灰度化
239
+
240
+ return output.detach().clone(), grid, grid_inverse, valid_mask.detach().clone()
241
+
242
+
243
+ def pre_processing(data: np.ndarray) -> np.ndarray:
244
+ """
245
+ 视网膜图像预处理流水线
246
+ 包含标准化、CLAHE对比度增强和伽马校正
247
+
248
+ 参数:
249
+ image: 输入视网膜图像 (单通道,numpy数组)
250
+
251
+ 返回:
252
+ 预处理后的图像 (float32, 范围[0,1])
253
+ """
254
+ train_imgs = datasets_normalized(data)
255
+ train_imgs = clahe_equalized(train_imgs)
256
+ train_imgs = adjust_gamma(train_imgs, 1.2)
257
+ train_imgs = train_imgs / 255. # 最终归一化到[0,1]范围
258
+ return train_imgs.astype(np.float32)
259
+
260
+
261
+ def datasets_normalized(images: np.ndarray) -> np.ndarray:
262
+ # 归一化之后还需要把值映射到0到255
263
+ # images_normalized = np.empty(images.shape)
264
+ # 计算全局统计量
265
+ images_std = np.std(images)
266
+ images_mean = np.mean(images)
267
+ # 应用标准化: (x - mean) / std
268
+ # 添加微小值避免除以零
269
+ images_normalized = (images - images_mean) / (images_std + 1e-6)
270
+ # 线性映射到[0,255]范围
271
+ minv = np.min(images_normalized)
272
+ images_normalized = ((images_normalized - minv) / (np.max(images_normalized) - minv)) * 255
273
+
274
+ return images_normalized
275
+
276
+
277
+ def clahe_equalized(images):
278
+ # 对输入图像的 CLAHE(Contrast Limited Adaptive Histogram Equalization)增强处理
279
+ # 用于提高图像的对比度,特别是在光照不均或细节难以分辨的情况下。
280
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
281
+ # clipLimit=2.0: 限制对比度的参数,值越低,增强的对比度越小,避免高对比度导致的过曝光区域。
282
+ # tileGridSize=(8, 8): 将图像分成大小为 8x8 的网格,每个网格单独进行直方图均衡化,减少全局对比度增强引入的伪影。
283
+ images_equalized = np.empty(images.shape)
284
+ images_equalized[:, :] = clahe.apply(np.array(images[:, :], dtype=np.uint8))
285
+
286
+ return images_equalized
287
+
288
+
289
+ def adjust_gamma(images, gamma=1.0):
290
+ # γ>1: 压缩高光区域,增强暗部细节 (适合视网膜图像)
291
+ # 创建伽马校正查找表
292
+ invGamma = 1.0 / gamma # invGamma: 伽马值的倒数,用于生成查找表。
293
+ table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8")
294
+ # 预计算伽马校正的转换值,用于快速查找。
295
+ # 每个输入像素值(0-255)都映射到一个经过伽马变换的输出值。
296
+ # 生成过程:
297
+ # i / 255.0: 将像素值归一化到 [0, 1] 范围。
298
+ # (i / 255.0) ** invGamma: 应用伽马校正公式。
299
+ # * 255: 将归一化后的值还原到 [0, 255] 范围。
300
+ # astype("uint8"): 转换为 uint8 数据类型,适配图像格式。
301
+ new_images = np.empty(images.shape)
302
+ # 应用伽马校正
303
+ new_images[:, :] = cv2.LUT(np.array(images[:, :], dtype=np.uint8), table)
304
+ # cv2.LUT: OpenCV 的快速像素值映射函数。
305
+ # 输入图像的每个像素值通过查找表 table 进行伽马校正。
306
+ # 大幅提高效率,避免逐像素计算。
307
+
308
+ return new_images
309
+
310
+
311
+ def simple_nms(scores, nms_radius: int):
312
+ """
313
+ 快速非极大值抑制 (NMS) 算法,用于移除相邻关键点
314
+ 参数:
315
+ scores: 关键点分数图 (B, H, W) 或 (H, W)
316
+ nms_radius: NMS邻域半径
317
+ 返回:
318
+ NMS处理后的分数图
319
+ """
320
+ assert (nms_radius >= 0)
321
+ # 计算NMS窗口大小 (2*半径+1)
322
+ size = nms_radius * 2 + 1
323
+ avg_size = 2
324
+ # 定义最大池化函数 (使用固定步长1和适当填充)
325
+ def max_pool(x):
326
+ return torch.nn.functional.max_pool2d(x, kernel_size=size, stride=1, padding=nms_radius)
327
+
328
+ # 创建与输入相同形状的零张量
329
+ zeros = torch.zeros_like(scores)
330
+ # max_map = max_pool(scores)
331
+
332
+ # 步骤1: 识别局部最大值点
333
+ # 比较每个点与其邻域内的最大值
334
+ max_mask = scores == max_pool(scores) # max_pool(scores):每个像素点被替换为其局部窗口内的最大值。
335
+ # 步骤2: 添加微小随机扰动 (避免多个相同最大值)
336
+ # 生成 [0, 0.1) 范围内的随机数
337
+ max_mask_ = torch.rand(max_mask.shape).to(max_mask.device) / 10
338
+ # 生成与 max_mask 相同形状的随机数(范围在 [0, 0.1)),作为微小扰动。
339
+ # 非局部最大值点置零
340
+ max_mask_[~max_mask] = 0
341
+ # 步骤3: 对扰动后的图再次应用NMS
342
+ # 识别扰动后仍然是局部最大值的点
343
+ mask = ((max_mask_ == max_pool(max_mask_)) & (max_mask_ > 0)) # mask:布尔掩码,仅保留扰动后仍然是局部最大值的点。
344
+ # 步骤4: 保留局部最大值点,其他点置零
345
+ return torch.where(mask, scores, zeros) # 如果 mask 为 True,保留原始分数。否则,将得分设置为零。
346
+
347
+
348
+ def remove_borders(keypoints, scores, border: int, height: int, width: int):
349
+ '''
350
+ 从提供的关键点和对应分数中移除那些过于靠近图像边缘的关键点
351
+ 这个函数是在推理的时候用的,训练的时候没有使用
352
+ keypoints: 关键点坐标的二维数组,形状为 (N, 2),其中每行表示一个关键点的 (y, x) 坐标。
353
+ scores: 每个关键点对应的分数,形状为 (N,)。
354
+ border: 表示需要移除的边界宽度。 推理时预设的是4像素
355
+ height: 图像的高度。
356
+ width: 图像的宽度。
357
+ '''
358
+ mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) # 创建高度方向掩码: 关键点必须在 [border, height-border) 范围内
359
+ mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) # 创建宽度方向掩码: 关键点必须在 [border, width-border) 范围内
360
+ mask = mask_h & mask_w # 组合掩码 (必须同时满足高度和宽度条件)
361
+ return keypoints[mask], scores[mask] # 返回过滤后的关键点和分数
362
+
363
+
364
+ def nms(detector_pred, nms_thresh=0.1, nms_size=10, detector_label=None, mask=False):
365
+ """
366
+ 在检测器预测上应用非极大值抑制 (NMS)
367
+
368
+ 参数:
369
+ detector_pred: 检测器预测 (B, 1, H, W)
370
+ nms_thresh: NMS阈值
371
+ nms_size: NMS邻域大小
372
+ detector_label: 检测器标签 (可选)
373
+ mask: 是否使用标签掩码 (当前未实现)
374
+
375
+ 返回:
376
+ 关键点位置列表 (每个元素是 (N, 2) 的数组)
377
+ """
378
+ detector_pred = detector_pred.clone().detach() # 创建预测副本 (避免修改原始数据)
379
+ B, _, h, w = detector_pred.shape # 获取批次大小和图像尺寸
380
+ # if mask:
381
+ # assert detector_label is not None
382
+ # detector_pred[detector_pred < nms_thresh] = 0
383
+ # label_mask = detector_label
384
+ #
385
+ # # more area
386
+ #
387
+ # detector_label = detector_label.long().cpu().numpy()
388
+ # detector_label = detector_label.astype(np.uint8)
389
+ # kernel = np.ones((3, 3), np.uint8)
390
+ # label_mask = np.array([cv2.dilate(detector_label[s, 0], kernel, iterations=1)
391
+ # for s in range(len(detector_label))])
392
+ # label_mask = torch.from_numpy(label_mask).unsqueeze(1)
393
+ # detector_pred[label_mask > 1e-6] = 0
394
+ scores = simple_nms(detector_pred, nms_size) # 应用快速NMS算法
395
+ scores = scores.reshape(B, h, w) # 重塑分数图形状 (B, H, W)
396
+ # print(f"scores before thresh {nms_thresh}", scores)
397
+ points = [torch.nonzero(s > nms_thresh) for s in scores] # 找出分数高于阈值的点
398
+ # print("points after thresh", points)
399
+ scores = [s[tuple(k.t())] for s, k in zip(scores, points)] # 提取这些点的分数值
400
+ points, scores = list(zip(*[ remove_borders(k, s, 8, h, w) for k, s in zip(points, scores)])) # 移除靠近边界的点
401
+ points = [torch.flip(k, [1]).long() for k in points] # 翻转坐标顺序: [y, x] -> [x, y]
402
+
403
+ return points