import argparse
import copy
import logging
import math
import os
import time
import warnings

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.io
import scipy.stats
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.optimize import curve_fit
from sklearn.impute import SimpleImputer
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, TensorDataset

from data_processing import split_train_test

warnings.filterwarnings("ignore", category=DeprecationWarning)


class Mlp(nn.Module):
    """Three-layer MLP regression head mapping a feature vector to a single quality score."""
    def __init__(self, input_features, hidden_features=256, out_features=1, drop_rate=0.2, act_layer=nn.GELU):
        super().__init__()
        self.fc1 = nn.Linear(input_features, hidden_features)
        self.act1 = act_layer()
        self.drop1 = nn.Dropout(drop_rate)
        self.fc2 = nn.Linear(hidden_features, hidden_features // 2)
        self.act2 = act_layer()
        self.drop2 = nn.Dropout(drop_rate)
        self.fc3 = nn.Linear(hidden_features // 2, out_features)

    def forward(self, input_feature):
        x = self.fc1(input_feature)
        x = self.act1(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.act2(x)
        x = self.drop2(x)
        output = self.fc3(x)
        return output

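# Quick shape check (illustrative only; the feature width 4096 is an arbitrary
# example, not the width produced by any particular backbone):
#   model = Mlp(input_features=4096)
#   scores = model(torch.randn(8, 4096))
#   assert scores.shape == (8, 1)
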
class MAEAndRankLoss(nn.Module):
    """Weighted sum of an L1 (MAE) term and a pairwise margin-based rank term."""
    def __init__(self, l1_w=1.0, rank_w=1.0, margin=0.0, use_margin=False):
        super(MAEAndRankLoss, self).__init__()
        self.l1_w = l1_w
        self.rank_w = rank_w
        self.margin = margin
        self.use_margin = use_margin

    def forward(self, y_pred, y_true):
        # MAE term
        l_mae = F.l1_loss(y_pred, y_true, reduction='mean') * self.l1_w

        # Rank term over all ordered pairs in the batch
        n = y_pred.size(0)
        pred_diff = y_pred.unsqueeze(1) - y_pred.unsqueeze(0)
        true_diff = y_true.unsqueeze(1) - y_true.unsqueeze(0)

        # Sign of the ground-truth difference encodes the desired ordering
        masks = torch.sign(true_diff)

        if self.use_margin and self.margin > 0:
            # Only penalise pairs whose ground-truth gap exceeds the margin
            true_diff = true_diff.abs() - self.margin
            true_diff = F.relu(true_diff)
            masks = true_diff.sign()

        l_rank = F.relu(true_diff - masks * pred_diff)
        # Normalise by the number of ordered pairs, n * (n - 1)
        l_rank = l_rank.sum() / (n * (n - 1))

        loss = l_mae + l_rank * self.rank_w
        return loss

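# Sanity check (illustrative): a prediction that reproduces the targets exactly
# incurs zero loss, since both the MAE term and every pairwise rank term vanish.
#   criterion = MAEAndRankLoss(l1_w=1.0, rank_w=1.0)
#   t = torch.tensor([1.0, 2.0, 3.0])
#   assert criterion(t.clone(), t).item() == 0.0
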
def load_data(csv_file, mat_file, features, data_name, set_name):
    """Load MOS labels from a CSV file and features from a MAT file (or an in-memory array for lsvq_train)."""
    try:
        df = pd.read_csv(csv_file, skiprows=[], header=None)
    except Exception as e:
        logging.error(f'Read CSV file error: {e}')
        raise

    try:
        if data_name == 'lsvq_train':
            X_mat = features
        else:
            X_mat = scipy.io.loadmat(mat_file)
    except Exception as e:
        logging.error(f'Read MAT file error: {e}')
        raise

    # MOS scores sit in the third CSV column; the first row is the header
    y_data = df.values[1:, 2]
    y = np.array(list(y_data), dtype=float)

    if data_name == 'cross_dataset':
        # Clip MOS to the 5-point scale used for cross-dataset evaluation
        y[y > 5] = 5
        if set_name == 'test':
            print(f"Modified y_true: {y}")
    if data_name == 'lsvq_train':
        X = np.asarray(X_mat, dtype=float)
    else:
        data_name = f'{data_name}_{set_name}_features'
        X = np.asarray(X_mat[data_name], dtype=float)

    return X, y

def preprocess_data(X, y):
    """Replace non-finite feature values, impute, and min-max scale the features."""
    X[np.isnan(X)] = 0
    X[np.isinf(X)] = 0
    # NaNs have already been zeroed above, so the imputer acts as a pass-through;
    # it is still fitted here so a reusable transformer is returned to the caller.
    imp = SimpleImputer(missing_values=np.nan, strategy='mean').fit(X)
    X = imp.transform(X)

    scaler = MinMaxScaler().fit(X)
    X = scaler.transform(X)
    logging.info(f'Scaler: {scaler}')

    y = y.reshape(-1, 1).squeeze()
    return X, y, imp, scaler

def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    """4-parameter logistic: f(X) = b2 + (b1 - b2) / (1 + exp(-(X - b3) / |b4|))."""
    logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4))))
    yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart)
    return yhat

def fit_logistic_regression(y_pred, y_true):
    """Fit the logistic mapping from raw predictions onto the MOS scale."""
    # Initial guess: asymptotes from the label range, midpoint at the prediction mean
    beta = [np.max(y_true), np.min(y_true), np.mean(y_pred), 0.5]
    popt, _ = curve_fit(logistic_func, y_pred, y_true, p0=beta, maxfev=100000000)
    y_pred_logistic = logistic_func(y_pred, *popt)
    return y_pred_logistic, beta, popt

def compute_correlation_metrics(y_true, y_pred):
    """Compute PLCC/RMSE on logistic-mapped predictions and SRCC/KRCC on raw predictions."""
    y_pred_logistic, beta, popt = fit_logistic_regression(y_pred, y_true)

    plcc = scipy.stats.pearsonr(y_true, y_pred_logistic)[0]
    rmse = np.sqrt(mean_squared_error(y_true, y_pred_logistic))
    srcc = scipy.stats.spearmanr(y_true, y_pred)[0]

    try:
        krcc = scipy.stats.kendalltau(y_true, y_pred)[0]
    except Exception as e:
        logging.error(f'krcc calculation: {e}')
        # Fall back to the asymptotic p-value computation
        krcc = scipy.stats.kendalltau(y_true, y_pred, method='asymptotic')[0]
    return y_pred_logistic, plcc, rmse, srcc, krcc

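# Usage sketch (synthetic data, illustrative only):
#   rng = np.random.default_rng(0)
#   y_true = rng.uniform(1, 5, 100)
#   y_pred = y_true + rng.normal(0, 0.2, 100)
#   _, plcc, rmse, srcc, krcc = compute_correlation_metrics(y_true, y_pred)
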
def plot_results(y_test, y_test_pred_logistic, df_pred_score, model_name, data_name, network_name, select_criteria):
    """Scatter-plot predicted scores against MOS with the fitted logistic curve overlaid."""
    mos1 = y_test
    y1 = y_test_pred_logistic

    try:
        beta = [np.max(mos1), np.min(mos1), np.mean(y1), 0.5]
        popt, pcov = curve_fit(logistic_func, y1, mos1, p0=beta, maxfev=100000000)
        sigma = np.sqrt(np.diag(pcov))
    except Exception as e:
        raise Exception('Fitting logistic function time-out!!') from e
    x_values1 = np.linspace(np.min(y1), np.max(y1), len(y1))
    plt.plot(x_values1, logistic_func(x_values1, *popt), '-', color='#c72e29', label='Fitted f(x)')

    fig1 = sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df_pred_score, markers='o', color='steelblue', label=network_name)
    plt.legend(loc='upper left')
    # Datasets rated on a 0-100 scale vs. the 1-5 MOS scale
    if data_name in ('live_vqc', 'live_qualcomm', 'cvd_2014', 'lsvq_train'):
        plt.ylim(0, 100)
        plt.xlim(0, 100)
    else:
        plt.ylim(1, 5)
        plt.xlim(1, 5)
    plt.title(f"Algorithm {network_name} with {model_name} on dataset {data_name}", fontsize=10)
    plt.xlabel('Predicted Score')
    plt.ylabel('MOS')
    reg_fig1 = fig1.get_figure()

    fig_path = f'../figs/{data_name}/'
    os.makedirs(fig_path, exist_ok=True)
    reg_fig1.savefig(fig_path + f"{network_name}_{model_name}_{data_name}_by{select_criteria}.png", dpi=300)
    plt.clf()
    plt.close()

def plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, test_vids, i):
    plt.figure(figsize=(10, 6))

    plt.plot(avg_train_losses, label='Average Training Loss')
    plt.plot(avg_val_losses, label='Average Validation Loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Average Training and Validation Loss Across Folds - {network_name} with {model_name} (test_vids: {test_vids})', fontsize=10)

    plt.legend()
    fig_par_path = f'../log/result/{data_name}/'
    os.makedirs(fig_par_path, exist_ok=True)
    plt.savefig(f'{fig_par_path}/{network_name}_Average_Training_Loss_test{i}.png', dpi=50)
    plt.clf()
    plt.close()

def configure_logging(log_path, model_name, data_name, network_name, select_criteria):
    log_file_name = os.path.join(log_path, f"{data_name}_{network_name}_{model_name}_corr_{select_criteria}.log")
    logging.basicConfig(filename=log_file_name, filemode='w', level=logging.DEBUG, format='%(levelname)s - %(message)s')
    logging.getLogger('matplotlib').setLevel(logging.WARNING)
    logging.info(f"Evaluating algorithm {network_name} with {model_name} on dataset {data_name}")
    logging.info(f"torch cuda: {torch.cuda.is_available()}")

def load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features):
    """Resolve the per-dataset CSV/MAT paths, load both splits, and preprocess them."""
    if data_name == 'cross_dataset':
        # Train on YouTube-UGC, test on CVD2014
        data_name1 = 'youtube_ugc_all'
        data_name2 = 'cvd_2014_all'
        csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name1}_MOS_train.csv')
        csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name2}_MOS_test.csv')
        mat_train_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name1}_{network_name}_train_features.mat')
        mat_test_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name2}_{network_name}_test_features.mat')
        X_train, y_train = load_data(csv_train_file, mat_train_file, None, data_name1, 'train')
        X_test, y_test = load_data(csv_test_file, mat_test_file, None, data_name2, 'test')

    elif data_name == 'lsvq_train':
        # LSVQ features are passed in memory rather than read from MAT files
        csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
        csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
        X_train, y_train = load_data(csv_train_file, None, train_features, data_name, 'train')
        X_test, y_test = load_data(csv_test_file, None, test_features, data_name, 'test')

    else:
        csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
        csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
        mat_train_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name}_{network_name}_train_features.mat')
        mat_test_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name}_{network_name}_test_features.mat')
        X_train, y_train = load_data(csv_train_file, mat_train_file, None, data_name, 'train')
        X_test, y_test = load_data(csv_test_file, mat_test_file, None, data_name, 'test')

    X_train, y_train, _, _ = preprocess_data(X_train, y_train)
    X_test, y_test, _, _ = preprocess_data(X_test, y_test)

    return X_train, y_train, X_test, y_test

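# Note: preprocess_data is called independently on the train and test splits, so
# each split is imputed and min-max scaled with its own fitted transformers.
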
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Train the model for one epoch"""
    model.train()
    train_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets.view(-1, 1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * inputs.size(0)
    train_loss /= len(train_loader.dataset)
    return train_loss

def evaluate(model, val_loader, criterion, device):
    """Evaluate model performance on the validation set"""
    model.eval()
    val_loss = 0.0
    y_val_pred = []
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            y_val_pred.extend(outputs.view(-1).tolist())
            loss = criterion(outputs, targets.view(-1, 1))
            val_loss += loss.item() * inputs.size(0)
    val_loss /= len(val_loader.dataset)
    return val_loss, np.array(y_val_pred)

def update_best_model(select_criteria, best_metric, current_val, model):
    """Return the better metric value, deep-copying the model on improvement."""
    is_better = False
    if select_criteria == 'byrmse' and current_val < best_metric:
        is_better = True
    elif select_criteria == 'bykrcc' and current_val > best_metric:
        is_better = True

    if is_better:
        return current_val, copy.deepcopy(model), is_better
    return best_metric, model, is_better

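# Selection semantics: 'byrmse' keeps the model with the lowest validation RMSE,
# while 'bykrcc' keeps the model with the highest validation KRCC.
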
def train_and_evaluate(X_train, y_train, config):
    """Train the MLP on an 80-20 train/validation split and return the best model plus loss histories."""
    n_repeats = config['n_repeats']
    batch_size = config['batch_size']
    epochs = config['epochs']
    hidden_features = config['hidden_features']
    drop_rate = config['drop_rate']
    loss_type = config['loss_type']
    optimizer_type = config['optimizer_type']
    select_criteria = config['select_criteria']
    initial_lr = config['initial_lr']
    weight_decay = config['weight_decay']
    patience = config['patience']
    l1_w = config['l1_w']
    rank_w = config['rank_w']
    use_swa = config.get('use_swa', False)
    logging.info(f'Parameters - Number of repeats for 80-20 hold out test: {n_repeats}, Batch size: {batch_size}, Number of epochs: {epochs}')
    logging.info(f'Network Parameters - hidden_features: {hidden_features}, drop_rate: {drop_rate}, patience: {patience}')
    logging.info(f'Optimizer Parameters - loss_type: {loss_type}, optimizer_type: {optimizer_type}, initial_lr: {initial_lr}, weight_decay: {weight_decay}, use_swa: {use_swa}')
    logging.info(f'MAEAndRankLoss - l1_w: {l1_w}, rank_w: {rank_w}')

    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

    best_model = None
    best_metric = float('inf') if select_criteria == 'byrmse' else float('-inf')

    all_train_losses = []
    all_val_losses = []

    model = Mlp(input_features=X_train.shape[1], hidden_features=hidden_features, drop_rate=drop_rate)
    model = model.to(device)  # `device` is set in the __main__ block

    if loss_type == 'MAERankLoss':
        criterion = MAEAndRankLoss()
        criterion.l1_w = l1_w
        criterion.rank_w = rank_w
    else:
        # Fix: the loss must be assigned, otherwise `criterion` is undefined below
        criterion = nn.MSELoss()

    if optimizer_type == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=weight_decay)
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)
    else:
        optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)
    if use_swa:
        swa_model = AveragedModel(model).to(device)
        swa_scheduler = SWALR(optimizer, swa_lr=initial_lr, anneal_strategy='cos')

    train_dataset = TensorDataset(torch.FloatTensor(X_train), torch.FloatTensor(y_train))
    val_dataset = TensorDataset(torch.FloatTensor(X_val), torch.FloatTensor(y_val))
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    train_losses, val_losses = [], []

    best_val_loss = float('inf')
    epochs_no_improve = 0
    early_stop_active = False
    # SWA kicks in for the last 30% of epochs; early stopping activates with it
    swa_start = int(epochs * 0.7) if use_swa else epochs

    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        scheduler.step()
        if use_swa and epoch >= swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
            early_stop_active = True
            print(f"Current learning rate with SWA: {swa_scheduler.get_last_lr()}")

        lr = optimizer.param_groups[0]['lr']
        print('Epoch %d: Learning rate: %f' % (epoch + 1, lr))

        # Validate on the SWA-averaged model once SWA is active
        current_model = swa_model if use_swa and epoch >= swa_start else model
        current_model.eval()
        val_loss, y_val_pred = evaluate(current_model, val_loader, criterion, device)
        val_losses.append(val_loss)
        print(f"Epoch {epoch + 1}, Training Loss: {train_loss}, Validation Loss: {val_loss}")

        y_val_pred = np.array(list(y_val_pred), dtype=float)
        _, _, rmse_val, _, krcc_val = compute_correlation_metrics(y_val, y_val_pred)
        current_metric = rmse_val if select_criteria == 'byrmse' else krcc_val
        best_metric, best_model, is_better = update_best_model(select_criteria, best_metric, current_metric, current_model)
        if is_better:
            logging.info(f"Epoch {epoch + 1}:")
            y_val_pred_logistic_tmp, plcc_valid_tmp, rmse_valid_tmp, srcc_valid_tmp, krcc_valid_tmp = compute_correlation_metrics(y_val, y_val_pred)
            logging.info(f'Validation set - Evaluation Results - SRCC: {srcc_valid_tmp}, KRCC: {krcc_valid_tmp}, PLCC: {plcc_valid_tmp}, RMSE: {rmse_valid_tmp}')

            X_train_fold_tensor = torch.FloatTensor(X_train).to(device)
            y_tra_pred_tmp = best_model(X_train_fold_tensor).detach().cpu().numpy().squeeze()
            y_tra_pred_tmp = np.array(list(y_tra_pred_tmp), dtype=float)
            y_tra_pred_logistic_tmp, plcc_train_tmp, rmse_train_tmp, srcc_train_tmp, krcc_train_tmp = compute_correlation_metrics(y_train, y_tra_pred_tmp)
            logging.info(f'Train set - Evaluation Results - SRCC: {srcc_train_tmp}, KRCC: {krcc_train_tmp}, PLCC: {plcc_train_tmp}, RMSE: {rmse_train_tmp}')

        # Note: once early stopping is active, the lowest-validation-loss snapshot
        # of the raw model overrides the metric-selected best_model above
        if early_stop_active:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model = copy.deepcopy(model)
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch + 1} epochs.")
                break

    if use_swa:
        # Recompute BatchNorm statistics for the final model; the collate_fn keeps batches on the device
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_to_device(x, device))
        best_model = best_model.to(device)
        best_model.eval()
        torch.optim.swa_utils.update_bn(train_loader, best_model)

    all_train_losses.append(train_losses)
    all_val_losses.append(val_losses)
    # Pad loss histories to equal length (early stopping can shorten a run)
    max_length = max(len(x) for x in all_train_losses)
    all_train_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_train_losses]
    max_length = max(len(x) for x in all_val_losses)
    all_val_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_val_losses]

    return best_model, all_train_losses, all_val_losses

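# A minimal config sketch for train_and_evaluate, mirroring the argparse defaults
# defined in __main__ (values shown for illustration):
#   config = {'n_repeats': 21, 'batch_size': 256, 'epochs': 120,
#             'hidden_features': 256, 'drop_rate': 0.1,
#             'loss_type': 'MAERankLoss', 'optimizer_type': 'sgd',
#             'select_criteria': 'byrmse', 'initial_lr': 1e-2,
#             'weight_decay': 0.0005, 'patience': 5,
#             'l1_w': 0.6, 'rank_w': 1.0, 'use_swa': True}
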
def collate_to_device(batch, device):
    data, targets = zip(*batch)
    return torch.stack(data).to(device), torch.stack(targets).to(device)

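# collate_to_device keeps whole batches on the target device so that
# torch.optim.swa_utils.update_bn can iterate the loader without device
# mismatches. update_bn only recomputes statistics for BatchNorm layers, and
# Mlp contains none, so for this model the call is effectively a safeguard.
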
def model_test(best_model, X, y, device):
    """Run the trained model over a dataset and return the raw predicted scores."""
    test_dataset = TensorDataset(torch.FloatTensor(X), torch.FloatTensor(y))
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    best_model.eval()
    y_pred = []
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)

            outputs = best_model(inputs)
            y_pred.extend(outputs.view(-1).tolist())

    return y_pred

def main(config):
    model_name = config['model_name']
    data_name = config['data_name']
    network_name = config['network_name']

    metadata_path = config['metadata_path']
    feature_path = config['feature_path']
    log_path = config['log_path']
    save_path = config['save_path']
    score_path = config['score_path']
    result_path = config['result_path']

    select_criteria = config['select_criteria']
    n_repeats = config['n_repeats']

    os.makedirs(log_path, exist_ok=True)
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(score_path, exist_ok=True)
    os.makedirs(result_path, exist_ok=True)
    result_file = f'{result_path}{data_name}_{network_name}_{select_criteria}.mat'
    pred_score_filename = os.path.join(score_path, f"{data_name}_{network_name}_{select_criteria}.csv")
    file_path = os.path.join(save_path, f"{data_name}_{network_name}_{select_criteria}_trained_median_model_param.pth")
    configure_logging(log_path, model_name, data_name, network_name, select_criteria)

    # ======================== Main Body ===========================
    PLCC_all_repeats_test = []
    SRCC_all_repeats_test = []
    KRCC_all_repeats_test = []
    RMSE_all_repeats_test = []
    PLCC_all_repeats_train = []
    SRCC_all_repeats_train = []
    KRCC_all_repeats_train = []
    RMSE_all_repeats_train = []
    all_repeats_test_vids = []
    all_repeats_df_test_pred = []
    best_model_list = []

    for i in range(1, n_repeats + 1):
        print(f"{i}th repeated 80-20 hold out test")
        logging.info(f"{i}th repeated 80-20 hold out test")
        t0 = time.time()

        # A different split seed per repeat
        test_size = 0.2
        random_state = math.ceil(8.8 * i)

        if data_name == 'lsvq_train':
            test_data_name = 'lsvq_test'
            train_features, test_features, test_vids = split_train_test.process_lsvq(data_name, test_data_name, metadata_path, feature_path, network_name)
        elif data_name == 'cross_dataset':
            train_data_name = 'youtube_ugc_all'
            test_data_name = 'cvd_2014_all'
            _, _, test_vids = split_train_test.process_cross_dataset(train_data_name, test_data_name, metadata_path, feature_path, network_name)
        else:
            _, _, test_vids = split_train_test.process_other(data_name, test_size, random_state, metadata_path, feature_path, network_name)

        # ======================== read files ===============================
        if data_name == 'lsvq_train':
            X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features)
        else:
            X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, None, None)

        # ======================== regression model ===============================
        best_model, all_train_losses, all_val_losses = train_and_evaluate(X_train, y_train, config)

        avg_train_losses = np.mean(all_train_losses, axis=0)
        avg_val_losses = np.mean(all_val_losses, axis=0)
        test_vids = test_vids.tolist()
        plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, len(test_vids), i)

        # Evaluate on the training split
        y_train_pred = model_test(best_model, X_train, y_train, device)
        y_train_pred = np.array(list(y_train_pred), dtype=float)
        y_train_pred_logistic, plcc_train, rmse_train, srcc_train, krcc_train = compute_correlation_metrics(y_train, y_train_pred)

        # Evaluate on the held-out test split
        y_test_pred = model_test(best_model, X_test, y_test, device)
        y_test_pred = np.array(list(y_test_pred), dtype=float)
        y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = compute_correlation_metrics(y_test, y_test_pred)

        test_pred_score = {'MOS': y_test, 'y_test_pred': y_test_pred, 'y_test_pred_logistic': y_test_pred_logistic}
        df_test_pred = pd.DataFrame(test_pred_score)

        logging.info("============================================================================================================")
        SRCC_all_repeats_test.append(srcc_test)
        KRCC_all_repeats_test.append(krcc_test)
        PLCC_all_repeats_test.append(plcc_test)
        RMSE_all_repeats_test.append(rmse_test)
        SRCC_all_repeats_train.append(srcc_train)
        KRCC_all_repeats_train.append(krcc_train)
        PLCC_all_repeats_train.append(plcc_train)
        RMSE_all_repeats_train.append(rmse_train)
        all_repeats_test_vids.append(test_vids)
        all_repeats_df_test_pred.append(df_test_pred)
        best_model_list.append(copy.deepcopy(best_model))

        logging.info('Best results in Mlp model within one split')
        logging.info(f'MODEL: {best_model}')
        logging.info('======================================================')
        logging.info('Train set - Evaluation Results')
        logging.info(f'SRCC_train: {srcc_train}')
        logging.info(f'KRCC_train: {krcc_train}')
        logging.info(f'PLCC_train: {plcc_train}')
        logging.info(f'RMSE_train: {rmse_train}')
        logging.info('======================================================')
        logging.info('Test set - Evaluation Results')
        logging.info(f'SRCC_test: {srcc_test}')
        logging.info(f'KRCC_test: {krcc_test}')
        logging.info(f'PLCC_test: {plcc_test}')
        logging.info(f'RMSE_test: {rmse_test}')
        logging.info('======================================================')
        logging.info(' -- {} seconds elapsed...\n\n'.format(time.time() - t0))

    logging.info('')
    # Replace NaNs so the summary statistics below are well defined
    SRCC_all_repeats_test = np.nan_to_num(SRCC_all_repeats_test)
    KRCC_all_repeats_test = np.nan_to_num(KRCC_all_repeats_test)
    PLCC_all_repeats_test = np.nan_to_num(PLCC_all_repeats_test)
    RMSE_all_repeats_test = np.nan_to_num(RMSE_all_repeats_test)
    SRCC_all_repeats_train = np.nan_to_num(SRCC_all_repeats_train)
    KRCC_all_repeats_train = np.nan_to_num(KRCC_all_repeats_train)
    PLCC_all_repeats_train = np.nan_to_num(PLCC_all_repeats_train)
    RMSE_all_repeats_train = np.nan_to_num(RMSE_all_repeats_train)
    logging.info('======================================================')
    logging.info('Median training results among all repeated 80-20 holdouts:')
    logging.info('SRCC: %f (std: %f)', np.median(SRCC_all_repeats_train), np.std(SRCC_all_repeats_train))
    logging.info('KRCC: %f (std: %f)', np.median(KRCC_all_repeats_train), np.std(KRCC_all_repeats_train))
    logging.info('PLCC: %f (std: %f)', np.median(PLCC_all_repeats_train), np.std(PLCC_all_repeats_train))
    logging.info('RMSE: %f (std: %f)', np.median(RMSE_all_repeats_train), np.std(RMSE_all_repeats_train))
    logging.info('======================================================')
    logging.info('Median testing results among all repeated 80-20 holdouts:')
    logging.info('SRCC: %f (std: %f)', np.median(SRCC_all_repeats_test), np.std(SRCC_all_repeats_test))
    logging.info('KRCC: %f (std: %f)', np.median(KRCC_all_repeats_test), np.std(KRCC_all_repeats_test))
    logging.info('PLCC: %f (std: %f)', np.median(PLCC_all_repeats_test), np.std(PLCC_all_repeats_test))
    logging.info('RMSE: %f (std: %f)', np.median(RMSE_all_repeats_test), np.std(RMSE_all_repeats_test))
    logging.info('======================================================')
    logging.info('\n')

    print('======================================================')
    # Pick the repeat whose test metric equals the median across all repeats
    if select_criteria == 'byrmse':
        median_metrics = np.median(RMSE_all_repeats_test)
        indices = np.where(RMSE_all_repeats_test == median_metrics)[0]
        select_criteria = select_criteria.replace('by', '').upper()
        print(RMSE_all_repeats_test)
        logging.info(f'all {select_criteria}: {RMSE_all_repeats_test}')
    elif select_criteria == 'bykrcc':
        median_metrics = np.median(KRCC_all_repeats_test)
        indices = np.where(KRCC_all_repeats_test == median_metrics)[0]
        select_criteria = select_criteria.replace('by', '').upper()
        print(KRCC_all_repeats_test)
        logging.info(f'all {select_criteria}: {KRCC_all_repeats_test}')

    median_test_vids = [all_repeats_test_vids[i] for i in indices]
    # Entries were converted with .tolist() inside the loop, so they are already Python lists
    test_vids = median_test_vids if len(median_test_vids) > 1 else (median_test_vids[0] if median_test_vids else [])

    median_model = None
    if len(indices) > 0:
        median_index = indices[0]
        median_model = best_model_list[median_index]
        median_model_df_test_pred = all_repeats_df_test_pred[median_index]

        median_model_df_test_pred.to_csv(pred_score_filename, index=False)
        # Plot the median repeat's own labels and predictions (stored in its DataFrame)
        plot_results(median_model_df_test_pred['MOS'].values, median_model_df_test_pred['y_test_pred_logistic'].values, median_model_df_test_pred, model_name, data_name, network_name, select_criteria)

        print(f'Median Metrics: {median_metrics}')
        print(f'Indices: {indices}')
        print(f'Best model: {median_model}')

        logging.info(f'median test {select_criteria}: {median_metrics}')
        logging.info(f"Indices of median metrics: {indices}")
        logging.info(f'Best model predict score: {median_model_df_test_pred}')
        logging.info(f'Best model: {median_model}')

    scipy.io.savemat(result_file,
                     mdict={'SRCC_train': np.asarray(SRCC_all_repeats_train, dtype=float),
                            'KRCC_train': np.asarray(KRCC_all_repeats_train, dtype=float),
                            'PLCC_train': np.asarray(PLCC_all_repeats_train, dtype=float),
                            'RMSE_train': np.asarray(RMSE_all_repeats_train, dtype=float),
                            'SRCC_test': np.asarray(SRCC_all_repeats_test, dtype=float),
                            'KRCC_test': np.asarray(KRCC_all_repeats_test, dtype=float),
                            'PLCC_test': np.asarray(PLCC_all_repeats_test, dtype=float),
                            'RMSE_test': np.asarray(RMSE_all_repeats_test, dtype=float),
                            f'Median_{select_criteria}': median_metrics,
                            'Test_Videos_list': all_repeats_test_vids,
                            'Test_videos_Median_model': test_vids,
                            })

    torch.save(median_model.state_dict(), file_path)
    print(f"Model state_dict saved to {file_path}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_name', type=str, default='Mlp')
    parser.add_argument('--data_name', type=str, default='cvd_2014', help='konvid_1k, youtube_ugc, live_vqc, cvd_2014, lsvq_train, cross_dataset')
    parser.add_argument('--network_name', type=str, default='relaxvqa', help='relaxvqa, {frag_name}_{network_name}_{layer_name}')

    parser.add_argument('--metadata_path', type=str, default='../metadata/')
    parser.add_argument('--feature_path', type=str, default='../features/')
    parser.add_argument('--log_path', type=str, default='../log/')
    parser.add_argument('--save_path', type=str, default='../model/')
    parser.add_argument('--score_path', type=str, default='../log/predict_score/')
    parser.add_argument('--result_path', type=str, default='../log/result/')

    parser.add_argument('--select_criteria', type=str, default='byrmse', help='byrmse, bykrcc')
    parser.add_argument('--n_repeats', type=int, default=21, help='Number of repeats for 80-20 hold out test')
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=120, help='Epochs for training')
    parser.add_argument('--hidden_features', type=int, default=256, help='Hidden features')
    parser.add_argument('--drop_rate', type=float, default=0.1, help='Dropout rate')

    parser.add_argument('--loss_type', type=str, default='MAERankLoss', help='MSEloss or MAERankLoss')
    parser.add_argument('--optimizer_type', type=str, default='sgd', help='adam or sgd')
    parser.add_argument('--initial_lr', type=float, default=1e-2, help='Initial learning rate: 1e-2')
    parser.add_argument('--weight_decay', type=float, default=0.0005, help='Weight decay (L2 penalty): 5e-4')
    parser.add_argument('--patience', type=int, default=5, help='Early stopping patience')
    # type=bool would treat any non-empty string (including "False") as True;
    # BooleanOptionalAction (Python 3.9+) adds a matching --no-use_swa flag instead
    parser.add_argument('--use_swa', action=argparse.BooleanOptionalAction, default=True, help='Use Stochastic Weight Averaging')
    parser.add_argument('--l1_w', type=float, default=0.6, help='MAE loss weight')
    parser.add_argument('--rank_w', type=float, default=1.0, help='Rank loss weight')

    args = parser.parse_args()
    config = vars(args)
    print(config)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    if device.type == "cuda":
        torch.cuda.set_device(0)

    main(config)
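
# Example invocation (illustrative; adjust the script name and paths to your layout):
#   python mlp_regression.py --data_name konvid_1k --network_name relaxvqa \
#       --select_criteria byrmse --n_repeats 21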