Spaces:
Sleeping
Sleeping
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------
| import numpy as np | |
| import torch | |
| from pytorch_fid.inception import InceptionV3 | |
| from pytorch_fid.fid_score import calculate_frechet_distance | |
class FIDMetric:
    """Accumulate InceptionV3 activations and compute FID between two image sets.

    Call :meth:`update` once per batch with the original images ``x`` and their
    reconstructions ``x_rec``. Then call :meth:`result` to get the Frechet
    Inception Distance between the two accumulated activation sets.
    :meth:`result` also clears the accumulated state.
    """

    def __init__(self, device, dims=2048):
        """
        Args:
            device: torch device the InceptionV3 feature extractor is moved to;
                input batches are moved here as well.
            dims (int): feature dimensionality; must be a key of
                ``InceptionV3.BLOCK_INDEX_BY_DIM``.
        """
        self.device = device
        # NOTE(review): not read anywhere in this class; kept in case external
        # code relies on the attribute.
        self.num_workers = 32
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
        self.model = InceptionV3([block_idx]).to(device)
        self.model.eval()
        self.reset_metrics()

    def reset_metrics(self):
        """Discard all activations collected so far."""
        self.x_pred = []
        self.x_rec_pred = []

    def get_activates(self, x: torch.Tensor):
        """Return InceptionV3 features for one batch as a 2-D ``(batch, dims)`` array.

        Args:
            x (torch.Tensor): image batch; moved to ``self.device`` before the
                forward pass (a no-op if it is already there).

        Returns:
            np.ndarray: one row per image. The batch axis is kept even for a
            batch of a single image, so batches concatenate cleanly along axis 0.
        """
        # Feature extraction only: no autograd graph is needed, and a tensor
        # that requires grad cannot be converted with .numpy().
        with torch.no_grad():
            pred = self.model(x.to(self.device))[0]
            # If model output is not scalar, apply global spatial average pooling.
            # This happens if you choose a dimensionality not equal 2048.
            # (Output is indexed as (batch, channels, H, W) here.)
            if pred.size(2) != 1 or pred.size(3) != 1:
                pred = torch.nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))
        # Drop only the two spatial axes: a bare squeeze() would also drop the
        # batch axis when the batch contains a single image.
        return pred.squeeze(3).squeeze(2).detach().cpu().numpy()

    def update(self, x: torch.Tensor, x_rec: torch.Tensor):
        """Record activations for one batch of originals and reconstructions.

        Args:
            x (torch.Tensor): input tensor range from 0 to 1
            x_rec (torch.Tensor): reconstructed tensor range from 0 to 1
        """
        self.x_pred.append(self.get_activates(x))
        self.x_rec_pred.append(self.get_activates(x_rec))

    @staticmethod
    def _gaussian_stats(acts):
        """Return the per-feature mean and feature covariance of ``acts`` (rows are samples)."""
        return np.mean(acts, axis=0), np.cov(acts, rowvar=False)

    def result(self):
        """Compute FID between all recorded originals and reconstructions, then reset.

        Returns:
            The value returned by ``calculate_frechet_distance`` for the two
            activation distributions.

        Raises:
            AssertionError: if no batch has been recorded since the last reset.
        """
        # Explicit raise (same exception type as before) so the check is not
        # stripped when Python runs with -O.
        if len(self.x_pred) == 0:
            raise AssertionError("No data to compute FID")
        x = np.concatenate(self.x_pred, axis=0)
        x_rec = np.concatenate(self.x_rec_pred, axis=0)
        x_mu, x_sigma = self._gaussian_stats(x)
        x_rec_mu, x_rec_sigma = self._gaussian_stats(x_rec)
        fid_score = calculate_frechet_distance(x_mu, x_sigma, x_rec_mu, x_rec_sigma)
        self.reset_metrics()
        return fid_score