Spaces:
Runtime error
Runtime error
| import torch | |
| import math | |
def _default_device(attn_maps_mid, attn_maps_up):
    """Choose the device for the zero loss returned when there are no objects."""
    up_maps = attn_maps_up[0] if len(attn_maps_up) > 0 else ()
    for group in (attn_maps_mid, up_maps):
        for attn_map in group:
            # Follow the first attention map we can find.
            return attn_map.device
    # No attention map to inspect: prefer CUDA (the historical behaviour)
    # but fall back to CPU instead of crashing when CUDA is unavailable.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _box_mask(boxes, height, width, device):
    """Build a (height, width) mask that is 1 inside any of ``boxes``, else 0."""
    mask = torch.zeros(size=(height, width), device=device)
    for box in boxes:
        # Box coordinates are scaled by the map size and truncated to ints.
        x_min, y_min = int(box[0] * width), int(box[1] * height)
        x_max, y_max = int(box[2] * width), int(box[3] * height)
        mask[y_min:y_max, x_min:x_max] = 1
    return mask


def _object_losses(attn_map_integrated, bboxes, object_positions):
    """Yield one loss term per object for a single attention map."""
    # Only the second half along dim 0 is used.
    # NOTE(review): presumably the text-conditioned half of a
    # classifier-free-guidance batch -- confirm against the caller.
    attn_map = attn_map_integrated.chunk(2)[1]
    b, i, _ = attn_map.shape
    # The flattened spatial axis is treated as a square H x W grid.
    H = W = int(math.sqrt(i))
    for obj_idx, boxes in enumerate(bboxes):
        positions = object_positions[obj_idx]
        mask = _box_mask(boxes, H, W, attn_map.device)
        obj_loss = 0
        for position in positions:
            ca_map_obj = attn_map[:, :, position].reshape(b, H, W)
            inside = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)
            total = ca_map_obj.reshape(b, -1).sum(dim=-1)
            # Fraction of this token's attention that falls inside the boxes.
            activation_value = inside / total
            obj_loss += torch.mean((1 - activation_value) ** 2)
        # Average over the token positions belonging to this object.
        yield obj_loss / len(positions)


def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
    """Compute the cross-attention layout loss over mid and up attention maps.

    For every attention map, every object and every token position of that
    object, the loss term is ``(1 - a) ** 2`` where ``a`` is the share of the
    token's attention that lies inside the object's bounding boxes. Terms are
    averaged over token positions, then over objects and attention maps.

    All tensors are created on the device of the attention maps, so the
    function works on CPU as well as on CUDA.

    Args:
        attn_maps_mid: Iterable of rank-3 attention tensors; the second half
            along dim 0 is used, dim 1 is the flattened spatial axis and
            dim 2 is indexed by token position.
        attn_maps_up: Sequence whose first element is an iterable of
            attention tensors with the same layout; only ``attn_maps_up[0]``
            is read.
        bboxes: One list of boxes per object. Each box is
            ``(x_min, y_min, x_max, y_max)``; values are multiplied by the
            map width/height, so they are presumably normalised to [0, 1].
        object_positions: One list of token indices per object, aligned with
            ``bboxes``.

    Returns:
        A scalar tensor holding the averaged loss; a float zero tensor when
        ``bboxes`` is empty.

    Raises:
        ZeroDivisionError: If an object has no token positions, or if there
            are objects but no attention maps at all.
    """
    object_number = len(bboxes)
    if object_number == 0:
        device = _default_device(attn_maps_mid, attn_maps_up)
        return torch.tensor(0).float().to(device)

    up_maps = attn_maps_up[0]
    loss = 0
    for group in (attn_maps_mid, up_maps):
        for attn_map_integrated in group:
            for term in _object_losses(attn_map_integrated, bboxes, object_positions):
                loss += term

    return loss / (object_number * (len(up_maps) + len(attn_maps_mid)))