Spaces:
Runtime error
Runtime error
Update dataloader.py
Browse files- dataloader.py +14 -21
dataloader.py
CHANGED
|
@@ -1,46 +1,39 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn.functional as F
|
| 3 |
from torch.utils.data import Dataset
|
| 4 |
-
|
| 5 |
import numpy as np
|
| 6 |
-
from random import shuffle
|
| 7 |
import os
|
| 8 |
-
import pandas as pd
|
| 9 |
|
| 10 |
-
class VRP_Dataset(Dataset):
|
| 11 |
|
|
|
|
| 12 |
def __init__(self, dataset_size, num_nodes, num_depots, dataset_path, device='cpu', *args, **kwargs):
|
| 13 |
super().__init__()
|
| 14 |
-
|
| 15 |
self.device = device
|
| 16 |
self.dataset_size = dataset_size
|
| 17 |
self.num_nodes = num_nodes
|
| 18 |
self.num_depots = num_depots
|
| 19 |
|
| 20 |
-
# Load
|
| 21 |
raw_data = pd.read_csv(dataset_path)
|
| 22 |
-
if len(raw_data) < dataset_size:
|
| 23 |
-
raise ValueError("
|
| 24 |
-
|
| 25 |
-
sampled_data = raw_data.sample(n=dataset_size, random_state=42).reset_index(drop=True)
|
| 26 |
-
|
| 27 |
-
# Extract coordinates (assuming columns named 'longitude', 'latitude')
|
| 28 |
-
coords = torch.tensor(sampled_data[['longitude', 'latitude']].values, dtype=torch.float32)
|
| 29 |
|
| 30 |
-
#
|
|
|
|
| 31 |
node_positions = coords.view(dataset_size, num_nodes, 2)
|
| 32 |
self.node_positions = node_positions
|
| 33 |
|
| 34 |
-
#
|
| 35 |
num_cars = num_nodes
|
| 36 |
launch_time = torch.zeros(dataset_size, num_cars, 1)
|
| 37 |
car_start_node = torch.randint(low=0, high=num_depots, size=(dataset_size, num_cars, 1))
|
| 38 |
self.fleet_data = {
|
| 39 |
'start_time': launch_time,
|
| 40 |
-
'car_start_node': car_start_node
|
| 41 |
}
|
| 42 |
|
| 43 |
-
#
|
| 44 |
a = torch.arange(num_nodes).reshape(1, 1, -1).repeat(dataset_size, num_cars, 1)
|
| 45 |
b = car_start_node.repeat(1, 1, num_nodes)
|
| 46 |
depot = ((a == b).sum(dim=1) > 0).float().unsqueeze(2)
|
|
@@ -63,13 +56,13 @@ class VRP_Dataset(Dataset):
|
|
| 63 |
def compute_distance_matrix(self, node_positions):
|
| 64 |
x = node_positions.unsqueeze(1).repeat(1, self.num_nodes, 1, 1)
|
| 65 |
y = node_positions.unsqueeze(2).repeat(1, 1, self.num_nodes, 1)
|
| 66 |
-
distance = (((x - y) ** 2).sum(dim=3))
|
| 67 |
return distance
|
| 68 |
|
| 69 |
def __getitem__(self, idx):
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
return
|
| 73 |
|
| 74 |
def __len__(self):
|
| 75 |
return self.dataset_size
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn.functional as F
|
| 3 |
from torch.utils.data import Dataset
|
| 4 |
+
import pandas as pd
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
import os
|
|
|
|
| 7 |
|
|
|
|
| 8 |
|
| 9 |
+
class VRP_Dataset(Dataset):
|
| 10 |
def __init__(self, dataset_size, num_nodes, num_depots, dataset_path, device='cpu', *args, **kwargs):
|
| 11 |
super().__init__()
|
|
|
|
| 12 |
self.device = device
|
| 13 |
self.dataset_size = dataset_size
|
| 14 |
self.num_nodes = num_nodes
|
| 15 |
self.num_depots = num_depots
|
| 16 |
|
| 17 |
+
# Load CSV data
|
| 18 |
raw_data = pd.read_csv(dataset_path)
|
| 19 |
+
if len(raw_data) < dataset_size * num_nodes:
|
| 20 |
+
raise ValueError("Not enough rows in CSV to build required dataset")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
# Randomly sample and reshape
|
| 23 |
+
coords = torch.tensor(raw_data[['longitude', 'latitude']].values[:dataset_size * num_nodes], dtype=torch.float32)
|
| 24 |
node_positions = coords.view(dataset_size, num_nodes, 2)
|
| 25 |
self.node_positions = node_positions
|
| 26 |
|
| 27 |
+
# Fleet data
|
| 28 |
num_cars = num_nodes
|
| 29 |
launch_time = torch.zeros(dataset_size, num_cars, 1)
|
| 30 |
car_start_node = torch.randint(low=0, high=num_depots, size=(dataset_size, num_cars, 1))
|
| 31 |
self.fleet_data = {
|
| 32 |
'start_time': launch_time,
|
| 33 |
+
'car_start_node': car_start_node,
|
| 34 |
}
|
| 35 |
|
| 36 |
+
# Graph data
|
| 37 |
a = torch.arange(num_nodes).reshape(1, 1, -1).repeat(dataset_size, num_cars, 1)
|
| 38 |
b = car_start_node.repeat(1, 1, num_nodes)
|
| 39 |
depot = ((a == b).sum(dim=1) > 0).float().unsqueeze(2)
|
|
|
|
| 56 |
def compute_distance_matrix(self, node_positions):
|
| 57 |
x = node_positions.unsqueeze(1).repeat(1, self.num_nodes, 1, 1)
|
| 58 |
y = node_positions.unsqueeze(2).repeat(1, 1, self.num_nodes, 1)
|
| 59 |
+
distance = torch.sqrt(((x - y) ** 2).sum(dim=3))
|
| 60 |
return distance
|
| 61 |
|
| 62 |
def __getitem__(self, idx):
|
| 63 |
+
graph = {key: self.graph_data[key][idx].unsqueeze(0).to(self.device) for key in self.graph_data}
|
| 64 |
+
fleet = {key: self.fleet_data[key][idx].unsqueeze(0).to(self.device) for key in self.fleet_data}
|
| 65 |
+
return graph, fleet
|
| 66 |
|
| 67 |
def __len__(self):
|
| 68 |
return self.dataset_size
|