import torch
from torch.nn.utils import clip_grad_norm_

from just_time_windows.Actor.actor import Actor


def train_batch(actor, baseline, batch, optimizer, gradient_clipping=True,
                comparison_model=None, compute_cost_ratio=True):
    """Run one REINFORCE update on a batch, using a greedy rollout baseline.

    Returns the ratio of the actor's total cost to a comparison model's
    total cost when compute_cost_ratio is True, otherwise None.
    """
    device = actor.device

    # Put the actor in stochastic sampling mode and enable training
    # behaviour for layers such as dropout / batch norm.
    actor.train_mode()
    actor.train()

    # Sample routes with the current policy; keep the log-probabilities
    # of the sampled actions for the policy-gradient loss.
    actor_output = actor(batch)
    actor_cost, log_probs = actor_output['total_time'], actor_output['log_probs']

    # Roll out the baseline greedily; no gradients are needed here.
    with torch.no_grad():
        baseline.greedy_search()
        baseline_output = baseline(batch)
        baseline_cost = baseline_output['total_time']

    # REINFORCE with baseline: the advantage (actor_cost - baseline_cost)
    # is detached so gradients flow only through log_probs.
    advantage = (actor_cost - baseline_cost).detach()
    loss = (advantage * log_probs).mean()

    optimizer.zero_grad()
    loss.backward()

    # Clip the L2 norm of the gradients per parameter group to stabilise
    # training.
    if gradient_clipping:
        for group in optimizer.param_groups:
            params = [p for p in group['params'] if p.grad is not None]
            if params:
                clip_grad_norm_(params, max_norm=1, norm_type=2)

    optimizer.step()

    # Optionally report the actor's total cost relative to a simple
    # comparison policy (a nearest-neighbour Actor, num_neighbors_action=1).
    if compute_cost_ratio:
        if comparison_model is None:
            normalize = actor.apply_normalization
            comparison_model = Actor(model=None, num_neighbors_action=1,
                                     normalize=normalize, device=device)
        with torch.no_grad():
            comp_output = comparison_model(batch)
            comp_cost = comp_output['total_time']
        comp_total = comp_cost.sum().item()
        actor_total = actor_cost.sum().item()
        return actor_total / comp_total

    return None
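

# Usage sketch: a minimal epoch loop showing how train_batch is typically
# driven. Assumptions not confirmed by this file: `dataloader` yields
# batches in whatever format the actor's forward pass expects, and the
# baseline is refreshed from the actor's weights elsewhere (e.g. between
# epochs); adapt the wiring to the actual project.
def train_epoch(actor, baseline, dataloader, optimizer):
    """Run train_batch over every batch and return the mean cost ratio."""
    ratios = []
    for batch in dataloader:
        ratio = train_batch(actor, baseline, batch, optimizer)
        if ratio is not None:
            ratios.append(ratio)
    # Mean actor-to-comparison cost ratio over the epoch (None if the
    # ratio computation was disabled for every batch).
    return sum(ratios) / len(ratios) if ratios else None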