Spaces:
Runtime error
Runtime error
import os
import time
from datetime import datetime

import numpy as np
import torch
class Raw_VRP_Data:
    """Randomly generated batch of VRP instances (fleet tensors + graph tensors).

    Each instance has ``num_nodes`` nodes placed uniformly at random in the unit
    square and one virtual car per node. Every car starts at a node index drawn
    uniformly from ``[0, num_depots)``; any node that is some car's start node is
    flagged as a depot. Non-depot nodes get a random time window; depots get
    start/end time 0.

    Args:
        dataset_size: Number of instances (leading batch dimension ``B``).
        num_nodes: Number of nodes ``N`` per instance; also the number of cars.
        num_depots: Car start nodes are drawn from indices ``0 .. num_depots - 1``.
            Must satisfy ``1 <= num_depots <= num_nodes``.
        seed: Optional seed for a private ``torch.Generator``. ``None`` (default)
            draws from torch's global RNG, in the same order as before.

    Raises:
        ValueError: If ``num_nodes < 1`` or ``num_depots`` is outside
            ``[1, num_nodes]``.

    Attributes:
        fleet: ``{'start_time': (B, N, 1) zeros, 'car_start_node': (B, N, 1) int64}``.
        graph: ``'start_times'``, ``'end_times'``, ``'depot'`` of shape (B, N, 1);
            ``'node_vector'`` of shape (B, N, 2); ``'distance_matrix'`` and
            ``'time_matrix'`` (an independent copy of the distances) of shape (B, N, N).
        data: ``{'fleet': fleet, 'graph': graph}``.
    """

    def __init__(self, dataset_size=1000, num_nodes=30, num_depots=1, seed=None):
        if num_nodes < 1:
            raise ValueError(f"num_nodes must be >= 1, got {num_nodes}")
        if not 1 <= num_depots <= num_nodes:
            # Start indices >= num_nodes would point outside the graph and be
            # silently dropped from the depot mask.
            raise ValueError(
                f"num_depots must be in [1, num_nodes={num_nodes}], got {num_depots}"
            )
        self.dataset_size = dataset_size
        self.num_nodes = num_nodes
        self.num_depots = num_depots
        num_cars = num_nodes  # every node gets one virtual car

        # None -> torch's global default generator (same stream as the unseeded code).
        gen = None if seed is None else torch.Generator().manual_seed(seed)

        # Fleet data: all cars launch at time 0 from a randomly chosen start index.
        launch_time = torch.zeros(self.dataset_size, num_cars, 1)
        car_start_node = torch.randint(
            low=0, high=num_depots, size=(self.dataset_size, num_cars, 1), generator=gen
        )
        self.fleet = {
            'start_time': launch_time,
            'car_start_node': car_start_node,
        }

        # Graph data.
        # A node is a depot iff at least one car starts there. Scattering 1s at
        # the start indices avoids materialising a (B, cars, N) comparison tensor.
        depot = torch.zeros(self.dataset_size, num_nodes, dtype=torch.float32)
        depot.scatter_(1, car_start_node.squeeze(2), 1.0)
        depot = depot.unsqueeze(2)
        not_depot = 1 - depot
        # Non-depot windows open in [3, 5) and last [0.1, 0.6); depots stay at 0.
        start_times = (torch.rand(self.dataset_size, num_nodes, 1, generator=gen) * 2 + 3) * not_depot
        end_times = start_times + (
            0.1 + 0.5 * torch.rand(self.dataset_size, num_nodes, 1, generator=gen)
        ) * not_depot
        node_positions = torch.rand(self.dataset_size, num_nodes, 2, generator=gen)
        distance_matrix = self.compute_distance_matrix(node_positions)
        time_matrix = distance_matrix.clone()  # travel time == distance (unit speed)
        self.graph = {
            'start_times': start_times,
            'end_times': end_times,
            'depot': depot,
            'node_vector': node_positions,
            'distance_matrix': distance_matrix,
            'time_matrix': time_matrix,
        }
        self.data = {
            'fleet': self.fleet,
            'graph': self.graph,
        }

    def compute_distance_matrix(self, node_positions):
        """Return pairwise Euclidean distances between the nodes of each instance.

        Args:
            node_positions: Tensor of shape (B, N, D).

        Returns:
            Tensor of shape (B, N, N); entry [b, i, j] is the distance between
            nodes i and j of instance b.
        """
        # (B, N, 1, D) - (B, 1, N, D) broadcasts to all pairs without copying and
        # works for any N, not only self.num_nodes.
        diff = node_positions.unsqueeze(2) - node_positions.unsqueeze(1)
        return torch.sqrt((diff ** 2).sum(dim=-1))

    def get_data(self):
        """Return the ``{'fleet': ..., 'graph': ...}`` dict built in ``__init__``."""
        return self.data

    def save_data(self, fp):
        """Serialise ``self.data`` with ``torch.save`` to ``fp`` (path or file-like)."""
        torch.save(self.data, fp)
if __name__ == '__main__':
    # Generation settings (values taken from params.json, hard-coded here).
    size = 5000
    num_nodes = 30
    num_depots = 1
    save_path = os.path.join(os.getcwd(), 'VRP_data.pt')
    print("Generating data...")
    # perf_counter is a monotonic, high-resolution clock meant for elapsed time;
    # timedelta.seconds would drop both the fractional part and whole days.
    start = time.perf_counter()
    raw_data = Raw_VRP_Data(dataset_size=size, num_nodes=num_nodes, num_depots=num_depots)
    raw_data.save_data(save_path)
    duration = time.perf_counter() - start
    print(f"✅ Data generation completed in {duration:.2f} seconds.")
    print(f"📦 Saved to: {save_path}")