Spaces:
Runtime error
Runtime error
Update build_data.py
Browse files- build_data.py +19 -24
build_data.py
CHANGED
|
@@ -1,33 +1,27 @@
|
|
| 1 |
import os
|
| 2 |
from datetime import datetime
|
| 3 |
-
import sys
|
| 4 |
-
|
| 5 |
import torch
|
| 6 |
import numpy as np
|
| 7 |
|
| 8 |
-
from torch.utils.data import Dataset
|
| 9 |
-
|
| 10 |
-
# Ensure project directory is included
|
| 11 |
|
| 12 |
class Raw_VRP_Data(object):
|
| 13 |
|
| 14 |
-
def __init__(self, dataset_size=1000, num_nodes=
|
| 15 |
-
|
| 16 |
self.dataset_size = dataset_size
|
| 17 |
self.num_nodes = num_nodes
|
| 18 |
-
num_cars = num_nodes
|
| 19 |
self.num_depots = num_depots
|
|
|
|
| 20 |
|
| 21 |
-
#
|
| 22 |
launch_time = torch.zeros(self.dataset_size, num_cars, 1)
|
| 23 |
car_start_node = torch.randint(low=0, high=num_depots, size=(self.dataset_size, num_cars, 1))
|
| 24 |
|
| 25 |
-
fleet = {
|
| 26 |
'start_time': launch_time,
|
| 27 |
'car_start_node': car_start_node,
|
| 28 |
}
|
| 29 |
|
| 30 |
-
#
|
| 31 |
a = torch.arange(num_nodes).reshape(1, 1, -1).repeat(self.dataset_size, num_cars, 1)
|
| 32 |
b = car_start_node.repeat(1, 1, num_nodes)
|
| 33 |
depot = ((a == b).sum(dim=1) > 0).float().unsqueeze(2)
|
|
@@ -35,11 +29,11 @@ class Raw_VRP_Data(object):
|
|
| 35 |
start_times = (torch.rand(self.dataset_size, num_nodes, 1) * 2 + 3) * (1 - depot)
|
| 36 |
end_times = start_times + (0.1 + 0.5 * torch.rand(self.dataset_size, num_nodes, 1)) * (1 - depot)
|
| 37 |
|
| 38 |
-
node_positions = torch.rand(dataset_size, num_nodes, 2)
|
| 39 |
distance_matrix = self.compute_distance_matrix(node_positions)
|
| 40 |
time_matrix = distance_matrix.clone()
|
| 41 |
|
| 42 |
-
graph = {
|
| 43 |
'start_times': start_times,
|
| 44 |
'end_times': end_times,
|
| 45 |
'depot': depot,
|
|
@@ -49,14 +43,14 @@ class Raw_VRP_Data(object):
|
|
| 49 |
}
|
| 50 |
|
| 51 |
self.data = {
|
| 52 |
-
'fleet': fleet,
|
| 53 |
-
'graph': graph
|
| 54 |
}
|
| 55 |
|
| 56 |
def compute_distance_matrix(self, node_positions):
|
| 57 |
x = node_positions.unsqueeze(1).repeat(1, self.num_nodes, 1, 1)
|
| 58 |
y = node_positions.unsqueeze(2).repeat(1, 1, self.num_nodes, 1)
|
| 59 |
-
distance = (((x - y)**2).sum(dim=3))
|
| 60 |
return distance
|
| 61 |
|
| 62 |
def get_data(self):
|
|
@@ -67,18 +61,19 @@ class Raw_VRP_Data(object):
|
|
| 67 |
|
| 68 |
|
| 69 |
if __name__ == '__main__':
|
| 70 |
-
|
| 71 |
-
size =
|
| 72 |
num_nodes = 30
|
| 73 |
num_depots = 1
|
| 74 |
|
|
|
|
|
|
|
|
|
|
| 75 |
start = datetime.now()
|
| 76 |
|
| 77 |
raw_data = Raw_VRP_Data(dataset_size=size, num_nodes=num_nodes, num_depots=num_depots)
|
|
|
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
end = datetime.now()
|
| 83 |
-
duration = (end - start).seconds
|
| 84 |
-
print(f"Data generation completed in {duration} seconds.")
|
|
|
|
| 1 |
import os
|
| 2 |
from datetime import datetime
|
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class Raw_VRP_Data(object):
|
| 8 |
|
| 9 |
+
def __init__(self, dataset_size=1000, num_nodes=30, num_depots=1):
|
|
|
|
| 10 |
self.dataset_size = dataset_size
|
| 11 |
self.num_nodes = num_nodes
|
|
|
|
| 12 |
self.num_depots = num_depots
|
| 13 |
+
num_cars = num_nodes # كل Node له Car افتراضي
|
| 14 |
|
| 15 |
+
# Fleet data
|
| 16 |
launch_time = torch.zeros(self.dataset_size, num_cars, 1)
|
| 17 |
car_start_node = torch.randint(low=0, high=num_depots, size=(self.dataset_size, num_cars, 1))
|
| 18 |
|
| 19 |
+
self.fleet = {
|
| 20 |
'start_time': launch_time,
|
| 21 |
'car_start_node': car_start_node,
|
| 22 |
}
|
| 23 |
|
| 24 |
+
# Graph data
|
| 25 |
a = torch.arange(num_nodes).reshape(1, 1, -1).repeat(self.dataset_size, num_cars, 1)
|
| 26 |
b = car_start_node.repeat(1, 1, num_nodes)
|
| 27 |
depot = ((a == b).sum(dim=1) > 0).float().unsqueeze(2)
|
|
|
|
| 29 |
start_times = (torch.rand(self.dataset_size, num_nodes, 1) * 2 + 3) * (1 - depot)
|
| 30 |
end_times = start_times + (0.1 + 0.5 * torch.rand(self.dataset_size, num_nodes, 1)) * (1 - depot)
|
| 31 |
|
| 32 |
+
node_positions = torch.rand(self.dataset_size, num_nodes, 2)
|
| 33 |
distance_matrix = self.compute_distance_matrix(node_positions)
|
| 34 |
time_matrix = distance_matrix.clone()
|
| 35 |
|
| 36 |
+
self.graph = {
|
| 37 |
'start_times': start_times,
|
| 38 |
'end_times': end_times,
|
| 39 |
'depot': depot,
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
self.data = {
|
| 46 |
+
'fleet': self.fleet,
|
| 47 |
+
'graph': self.graph
|
| 48 |
}
|
| 49 |
|
| 50 |
def compute_distance_matrix(self, node_positions):
|
| 51 |
x = node_positions.unsqueeze(1).repeat(1, self.num_nodes, 1, 1)
|
| 52 |
y = node_positions.unsqueeze(2).repeat(1, 1, self.num_nodes, 1)
|
| 53 |
+
distance = torch.sqrt(((x - y) ** 2).sum(dim=3))
|
| 54 |
return distance
|
| 55 |
|
| 56 |
def get_data(self):
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
if __name__ == '__main__':
|
| 64 |
+
# إعدادات من params.json
|
| 65 |
+
size = 5000
|
| 66 |
num_nodes = 30
|
| 67 |
num_depots = 1
|
| 68 |
|
| 69 |
+
save_path = os.path.join(os.getcwd(), 'VRP_data.pt')
|
| 70 |
+
|
| 71 |
+
print("Generating data...")
|
| 72 |
start = datetime.now()
|
| 73 |
|
| 74 |
raw_data = Raw_VRP_Data(dataset_size=size, num_nodes=num_nodes, num_depots=num_depots)
|
| 75 |
+
raw_data.save_data(save_path)
|
| 76 |
|
| 77 |
+
duration = (datetime.now() - start).seconds
|
| 78 |
+
print(f"✅ Data generation completed in {duration} seconds.")
|
| 79 |
+
print(f"📦 Saved to: {save_path}")
|
|
|
|
|
|
|
|
|