Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset | |
| import pandas as pd | |
| import numpy as np | |
| import os | |
class VRP_Dataset(Dataset):
    """Vehicle-routing dataset built from node coordinates stored in a CSV file.

    The CSV must provide ``lng`` and ``lat`` columns.  The first
    ``dataset_size * num_nodes`` rows are taken in file order (no shuffling)
    and reshaped into ``dataset_size`` instances of ``num_nodes`` nodes each.

    Attributes:
        node_positions: float32 tensor of shape (dataset_size, num_nodes, 2)
            holding ``(lng, lat)`` per node.
        fleet_data: dict with ``'start_time'`` (all zeros) and
            ``'car_start_node'`` (random node index in ``[0, num_depots)``),
            both of shape (dataset_size, num_nodes, 1); one car per node.
        graph_data: dict with ``'start_times'``, ``'end_times'``, ``'depot'``
            (each (dataset_size, num_nodes, 1)), ``'node_vector'`` (the node
            positions), ``'distance_matrix'`` and ``'time_matrix'`` (each
            (dataset_size, num_nodes, num_nodes)).
    """

    def __init__(self, dataset_size, num_nodes, num_depots, dataset_path, device='cpu', *args, **kwargs):
        """Load coordinates from ``dataset_path`` and generate the instance data.

        Args:
            dataset_size: number of problem instances to build.
            num_nodes: number of nodes per instance (also the number of cars).
            num_depots: cars start on a random node index in ``[0, num_depots)``.
            dataset_path: path of a CSV file with ``lng`` and ``lat`` columns.
            device: device that items are moved to in ``__getitem__``.
            *args, **kwargs: accepted for call compatibility and ignored.

        Raises:
            ValueError: if the CSV holds fewer than
                ``dataset_size * num_nodes`` rows.
        """
        super().__init__()
        self.device = device
        self.dataset_size = dataset_size
        self.num_nodes = num_nodes
        self.num_depots = num_depots

        # Read exactly as many rows as are needed.  A fixed row cap would
        # reject large requests even when the file is big enough, and would
        # read more than necessary for small ones.
        required_rows = dataset_size * num_nodes
        raw_data = pd.read_csv(dataset_path, nrows=required_rows)
        if len(raw_data) < required_rows:
            raise ValueError("Not enough rows in CSV to build required dataset")

        # Rows are consumed in file order: consecutive runs of `num_nodes`
        # rows form one instance.
        coords = torch.tensor(
            raw_data[['lng', 'lat']].to_numpy(dtype='float32')[:required_rows],
            dtype=torch.float32,
        )
        node_positions = coords.view(dataset_size, num_nodes, 2)
        self.node_positions = node_positions

        # Fleet data: one car per node, all launching at time zero.
        num_cars = num_nodes
        launch_time = torch.zeros(dataset_size, num_cars, 1)
        car_start_node = torch.randint(low=0, high=num_depots, size=(dataset_size, num_cars, 1))
        self.fleet_data = {
            'start_time': launch_time,
            'car_start_node': car_start_node,
        }

        # Graph data.  A node is flagged as a depot (1.0) when at least one
        # car starts on it.  Broadcasting (1, 1, N) against (B, C, 1) yields
        # the (B, C, N) match table without materialising repeated copies.
        node_ids = torch.arange(num_nodes).reshape(1, 1, -1)
        depot = (node_ids == car_start_node).any(dim=1).float().unsqueeze(2)
        not_depot = 1 - depot

        # Time windows: start is uniform in [3, 5) and the window length is
        # uniform in [0.1, 0.6); both are zeroed for depot nodes.
        start_times = (torch.rand(dataset_size, num_nodes, 1) * 2 + 3) * not_depot
        end_times = start_times + (0.1 + 0.5 * torch.rand(dataset_size, num_nodes, 1)) * not_depot

        distance_matrix = self.compute_distance_matrix(node_positions)
        # Travel time is taken to be equal to the distance.
        time_matrix = distance_matrix.clone()

        self.graph_data = {
            'start_times': start_times,
            'end_times': end_times,
            'depot': depot,
            'node_vector': node_positions,
            'distance_matrix': distance_matrix,
            'time_matrix': time_matrix
        }

    def compute_distance_matrix(self, node_positions):
        """Return pairwise Euclidean distances between the nodes of each instance.

        Args:
            node_positions: tensor of shape (batch, nodes, dims).

        Returns:
            Tensor of shape (batch, nodes, nodes) whose entry ``[b, i, j]`` is
            the Euclidean distance between nodes ``i`` and ``j`` of instance
            ``b``, computed directly on the raw coordinate values.
        """
        # (B, 1, N, D) - (B, N, 1, D) broadcasts to (B, N, N, D).
        diff = node_positions.unsqueeze(1) - node_positions.unsqueeze(2)
        return torch.sqrt((diff ** 2).sum(dim=3))

    def __getitem__(self, idx):
        """Return ``(graph, fleet)`` dicts for instance ``idx``.

        Every tensor gets a leading dimension of size 1 and is moved to
        ``self.device``.
        """
        graph = {key: value[idx].unsqueeze(0).to(self.device) for key, value in self.graph_data.items()}
        fleet = {key: value[idx].unsqueeze(0).to(self.device) for key, value in self.fleet_data.items()}
        return graph, fleet

    def __len__(self):
        """Return the number of instances in the dataset."""
        return self.dataset_size

    def collate(self, batch):
        """Merge a list of ``(graph, fleet)`` items into one batched pair.

        Tensors stored under the same key are concatenated along dim 0.
        """
        graph_data = {key: torch.cat([item[0][key] for item in batch], dim=0) for key in self.graph_data}
        fleet_data = {key: torch.cat([item[1][key] for item in batch], dim=0) for key in self.fleet_data}
        return graph_data, fleet_data

    def get_batch(self, idx, batch_size=10):
        """Return the collated instances ``idx .. idx + batch_size - 1``.

        Raises:
            IndexError: if the requested range runs past the end of the data.
        """
        return self.collate([self[i] for i in range(idx, idx + batch_size)])

    def get_data(self):
        """Return the full ``(graph_data, fleet_data)`` dicts (not copies)."""
        return self.graph_data, self.fleet_data

    def model_input_length(self):
        """Return 3 plus the size of the last dimension of ``'node_vector'``."""
        return 3 + self.graph_data['node_vector'].shape[2]

    def save_data(self, fp):
        """Serialize ``(graph_data, fleet_data)`` to path ``fp`` with ``torch.save``."""
        data = (self.graph_data, self.fleet_data)
        with open(fp, 'wb') as f:
            torch.save(data, f)