Spaces:
Runtime error
Runtime error
Update train_test_utils/baseline.py
Browse files- train_test_utils/baseline.py +28 -19
train_test_utils/baseline.py
CHANGED
|
@@ -1,10 +1,22 @@
|
|
| 1 |
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
from torch.nn.utils import clip_grad_norm_
|
| 4 |
from torch.utils.data import DataLoader
|
| 5 |
|
| 6 |
|
| 7 |
-
def update_baseline(actor, baseline, validation_set, record_scores, batch_size=100, threshold=0.95):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
val_dataloader = DataLoader(dataset=validation_set,
|
| 10 |
batch_size=batch_size,
|
|
@@ -17,25 +29,22 @@ def update_baseline(actor, baseline, validation_set, record_scores, batch_size=1
|
|
| 17 |
for batch in val_dataloader:
|
| 18 |
with torch.no_grad():
|
| 19 |
actor_output = actor(batch)
|
| 20 |
-
actor_cost = actor_output['total_time']
|
| 21 |
-
|
| 22 |
-
actor_scores.append(actor_cost)
|
| 23 |
-
actor_scores = torch.cat(actor_scores, dim=0)
|
| 24 |
|
|
|
|
|
|
|
| 25 |
|
| 26 |
if record_scores is None:
|
| 27 |
baseline.load_state_dict(actor.state_dict())
|
| 28 |
-
|
| 29 |
-
return record_scores
|
| 30 |
-
else:
|
| 31 |
|
| 32 |
-
|
| 33 |
-
print('\n', flush=True)
|
| 34 |
-
print('baseline updated', flush=True)
|
| 35 |
-
print('\n', flush=True)
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
|
|
|
|
|
|
| 2 |
from torch.utils.data import DataLoader
|
| 3 |
|
| 4 |
|
| 5 |
+
def update_baseline(actor, baseline, validation_set, record_scores=None, batch_size=100, threshold=0.95):
|
| 6 |
+
"""
|
| 7 |
+
Evaluate the actor on the validation set and update the baseline if performance improves.
|
| 8 |
+
|
| 9 |
+
Parameters:
|
| 10 |
+
- actor: current model being trained
|
| 11 |
+
- baseline: model used as the performance reference
|
| 12 |
+
- validation_set: dataset used for evaluation
|
| 13 |
+
- record_scores: previously recorded baseline scores
|
| 14 |
+
- batch_size: batch size for validation
|
| 15 |
+
- threshold: (optional) threshold for improvement (not used in current implementation)
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
- updated record_scores
|
| 19 |
+
"""
|
| 20 |
|
| 21 |
val_dataloader = DataLoader(dataset=validation_set,
|
| 22 |
batch_size=batch_size,
|
|
|
|
| 29 |
for batch in val_dataloader:
|
| 30 |
with torch.no_grad():
|
| 31 |
actor_output = actor(batch)
|
| 32 |
+
actor_cost = actor_output['total_time'].view(-1)
|
| 33 |
+
actor_scores.append(actor_cost)
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
actor_scores = torch.cat(actor_scores, dim=0)
|
| 36 |
+
actor_score_mean = actor_scores.mean().item()
|
| 37 |
|
| 38 |
if record_scores is None:
|
| 39 |
baseline.load_state_dict(actor.state_dict())
|
| 40 |
+
return actor_scores
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
baseline_score_mean = record_scores.mean().item()
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
+
if actor_score_mean < baseline_score_mean:
|
| 45 |
+
print(f"\nBaseline updated: {baseline_score_mean:.4f} → {actor_score_mean:.4f}\n", flush=True)
|
| 46 |
+
baseline.load_state_dict(actor.state_dict())
|
| 47 |
+
return actor_scores
|
| 48 |
+
else:
|
| 49 |
+
print(f"\nNo improvement: {actor_score_mean:.4f} ≥ {baseline_score_mean:.4f}\n", flush=True)
|
| 50 |
+
return record_scores
|