Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.utils.data import DataLoader | |
def update_baseline(actor, baseline, validation_set, record_scores=None, batch_size=100, threshold=0.95):
    """
    Evaluate the actor on the validation set and update the baseline if it improves.

    The actor is scored on every validation batch; its per-instance
    ``'total_time'`` outputs are flattened and concatenated into one score
    tensor. Scores are treated as costs: the baseline is replaced only when the
    actor's mean score is strictly lower than the recorded baseline mean.

    Side effects:
    - ``actor.greedy_search()`` and ``actor.eval()`` are called and are NOT
      reverted here; the caller must restore training mode / search strategy.
    - ``baseline`` receives a copy of the actor's weights (via
      ``load_state_dict``) on the first call (``record_scores is None``) and
      whenever the actor improves.
    - A status line is printed when ``record_scores`` is provided.

    Parameters:
    - actor: current model being trained; must be callable on a collated batch
      and return a mapping containing a ``'total_time'`` tensor
    - baseline: model used as the performance reference
    - validation_set: dataset used for evaluation; must expose a ``collate``
      callable that is used as the DataLoader ``collate_fn``
    - record_scores: previously recorded baseline scores (tensor), or None if
      no baseline has been recorded yet
    - batch_size: batch size for validation
    - threshold: (optional) threshold for improvement (not used in current implementation)

    Returns:
    - updated record_scores: the actor's score tensor if the baseline was
      (re)initialised or improved upon, otherwise the unchanged ``record_scores``
    """
    val_dataloader = DataLoader(dataset=validation_set,
                                batch_size=batch_size,
                                collate_fn=validation_set.collate)

    # Put the actor into its evaluation configuration before scoring.
    actor.greedy_search()
    actor.eval()

    # Evaluation only: one no_grad scope for the whole pass avoids building
    # autograd graphs and re-entering the context manager per batch.
    with torch.no_grad():
        actor_scores = torch.cat(
            [actor(batch)['total_time'].view(-1) for batch in val_dataloader],
            dim=0,
        )
    actor_score_mean = actor_scores.mean().item()

    # First call: nothing to compare against, so the actor becomes the baseline.
    if record_scores is None:
        baseline.load_state_dict(actor.state_dict())
        return actor_scores

    baseline_score_mean = record_scores.mean().item()

    # Lower mean cost wins; ties keep the existing baseline.
    if actor_score_mean < baseline_score_mean:
        print(f"\nBaseline updated: {baseline_score_mean:.4f} → {actor_score_mean:.4f}\n", flush=True)
        baseline.load_state_dict(actor.state_dict())
        return actor_scores

    print(f"\nNo improvement: {actor_score_mean:.4f} ≥ {baseline_score_mean:.4f}\n", flush=True)
    return record_scores