Spaces:
Running
Running
Sadjad Alikhani
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,16 +12,8 @@ from sklearn.metrics import confusion_matrix
|
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
import pandas as pd
|
| 14 |
|
| 15 |
-
# Paths to the predefined images folder
|
| 16 |
-
RAW_PATH = os.path.join("images", "raw")
|
| 17 |
-
EMBEDDINGS_PATH = os.path.join("images", "embeddings")
|
| 18 |
-
|
| 19 |
-
# Specific values for percentage of data for training
|
| 20 |
-
percentage_values = (np.arange(9) + 1)*10
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
|
|
|
|
| 25 |
def beam_prediction_task(data_percentage, task_complexity):
|
| 26 |
# Folder naming convention based on input_type, data_percentage, and task_complexity
|
| 27 |
raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
|
|
@@ -92,40 +84,6 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
|
| 92 |
plt.savefig(save_path)
|
| 93 |
plt.close()
|
| 94 |
|
| 95 |
-
|
| 96 |
-
#def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
| 97 |
-
# plt.figure(figsize=(8, 6))
|
| 98 |
-
# plt.imshow(cm, interpolation='nearest', cmap='coolwarm')
|
| 99 |
-
# plt.title(title)
|
| 100 |
-
# plt.colorbar()
|
| 101 |
-
# tick_marks = np.arange(len(classes))
|
| 102 |
-
# plt.xticks(tick_marks, classes, rotation=45)
|
| 103 |
-
# plt.yticks(tick_marks, classes)
|
| 104 |
-
#
|
| 105 |
-
# plt.tight_layout()
|
| 106 |
-
# plt.ylabel('True label')
|
| 107 |
-
# plt.xlabel('Predicted label')
|
| 108 |
-
# plt.savefig(save_path)
|
| 109 |
-
# plt.close()
|
| 110 |
-
|
| 111 |
-
# Function to compute the average confusion matrix across CSV files in a folder
|
| 112 |
-
#def compute_average_confusion_matrix(folder):
|
| 113 |
-
# confusion_matrices = []
|
| 114 |
-
# for file in os.listdir(folder):
|
| 115 |
-
# if file.endswith(".csv"):
|
| 116 |
-
# data = pd.read_csv(os.path.join(folder, file))
|
| 117 |
-
# y_true = data["Target"]
|
| 118 |
-
# y_pred = data["Top-1 Prediction"]
|
| 119 |
-
# num_labels = len(np.unique(y_true))
|
| 120 |
-
# cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_labels))
|
| 121 |
-
# confusion_matrices.append(cm)
|
| 122 |
-
#
|
| 123 |
-
# if confusion_matrices:
|
| 124 |
-
# avg_cm = np.mean(confusion_matrices, axis=0)
|
| 125 |
-
# return avg_cm
|
| 126 |
-
# else:
|
| 127 |
-
# return None
|
| 128 |
-
|
| 129 |
def compute_average_confusion_matrix(folder):
|
| 130 |
confusion_matrices = []
|
| 131 |
max_num_labels = 0
|
|
@@ -162,10 +120,99 @@ def compute_average_confusion_matrix(folder):
|
|
| 162 |
else:
|
| 163 |
return None
|
| 164 |
|
|
|
|
| 165 |
|
| 166 |
|
|
|
|
|
|
|
| 167 |
|
|
|
|
|
|
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Custom class to capture print output
|
| 171 |
class PrintCapture(io.StringIO):
|
|
@@ -410,7 +457,7 @@ def process_hdf5_file(uploaded_file, percentage_idx):
|
|
| 410 |
os.chdir(original_dir)
|
| 411 |
sys.stdout = sys.__stdout__ # Reset print statements
|
| 412 |
|
| 413 |
-
|
| 414 |
with gr.Blocks(css="""
|
| 415 |
.slider-container {
|
| 416 |
display: inline-block;
|
|
@@ -439,17 +486,35 @@ with gr.Blocks(css="""
|
|
| 439 |
# Separate Tab for LoS/NLoS Classification Task
|
| 440 |
with gr.Tab("LoS/NLoS Classification Task"):
|
| 441 |
gr.Markdown("### LoS/NLoS Classification Task")
|
| 442 |
-
file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"])
|
| 443 |
|
| 444 |
-
|
| 445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
with gr.Row():
|
| 447 |
raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
|
| 448 |
embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
|
| 449 |
output_textbox = gr.Textbox(label="Console Output", lines=10)
|
| 450 |
|
| 451 |
-
#
|
| 452 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
|
| 454 |
# Launch the app
|
| 455 |
if __name__ == "__main__":
|
|
|
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
import pandas as pd
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
#################### BEAM PREDICTION #########################}
|
| 17 |
def beam_prediction_task(data_percentage, task_complexity):
|
| 18 |
# Folder naming convention based on input_type, data_percentage, and task_complexity
|
| 19 |
raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
|
|
|
|
| 84 |
plt.savefig(save_path)
|
| 85 |
plt.close()
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
def compute_average_confusion_matrix(folder):
|
| 88 |
confusion_matrices = []
|
| 89 |
max_num_labels = 0
|
|
|
|
| 120 |
else:
|
| 121 |
return None
|
| 122 |
|
| 123 |
+
########################## LOS/NLOS CLASSIFICATION #############################3
|
| 124 |
|
| 125 |
|
| 126 |
+
# Paths to the predefined images folder
|
| 127 |
+
LOS_PATH = "images_LoS"
|
| 128 |
|
| 129 |
+
# Define the percentage values
|
| 130 |
+
percentage_values_los = np.linspace(0.1, 1, 20) * 100 # 20 percentage values
|
| 131 |
|
| 132 |
+
# Function to compute confusion matrix and plot it
|
| 133 |
+
def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
|
| 134 |
+
# Load CSV file
|
| 135 |
+
data = pd.read_csv(csv_file_path)
|
| 136 |
+
|
| 137 |
+
# Extract ground truth and predictions
|
| 138 |
+
y_true = data['ground-truth']
|
| 139 |
+
y_pred = data['predicted']
|
| 140 |
+
|
| 141 |
+
# Compute confusion matrix
|
| 142 |
+
cm = confusion_matrix(y_true, y_pred)
|
| 143 |
+
|
| 144 |
+
# Plot the confusion matrix
|
| 145 |
+
plt.figure(figsize=(5, 5))
|
| 146 |
+
plt.imshow(cm, interpolation='nearest', cmap='Blues')
|
| 147 |
+
plt.title(title)
|
| 148 |
+
plt.colorbar()
|
| 149 |
+
plt.xticks([0, 1], labels=['Class 0', 'Class 1'])
|
| 150 |
+
plt.yticks([0, 1], labels=['Class 0', 'Class 1'])
|
| 151 |
+
|
| 152 |
+
# Annotate the confusion matrix
|
| 153 |
+
thresh = cm.max() / 2
|
| 154 |
+
for i in range(cm.shape[0]):
|
| 155 |
+
for j in range(cm.shape[1]):
|
| 156 |
+
plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",
|
| 157 |
+
color="white" if cm[i, j] > thresh else "black")
|
| 158 |
+
|
| 159 |
+
plt.ylabel('True label')
|
| 160 |
+
plt.xlabel('Predicted label')
|
| 161 |
+
plt.tight_layout()
|
| 162 |
+
|
| 163 |
+
# Save the plot as an image
|
| 164 |
+
plt.savefig(save_path)
|
| 165 |
+
plt.close()
|
| 166 |
+
|
| 167 |
+
# Return the saved image
|
| 168 |
+
return Image.open(save_path)
|
| 169 |
+
|
| 170 |
+
# Function to load confusion matrix based on percentage and input_type
|
| 171 |
+
def display_confusion_matrices_los(percentage_idx):
|
| 172 |
+
percentage = percentage_values_los[percentage_idx]
|
| 173 |
+
|
| 174 |
+
# Construct folder names
|
| 175 |
+
raw_folder = os.path.join(LOS_PATH, f"raw_{percentage/100:.3f}_los_noTraining")
|
| 176 |
+
embeddings_folder = os.path.join(LOS_PATH, f"embedding_{percentage/100:.3f}_los_noTraining")
|
| 177 |
+
|
| 178 |
+
# Process raw confusion matrix
|
| 179 |
+
raw_csv_file = os.path.join(raw_folder, "confusion_matrix.csv")
|
| 180 |
+
raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
| 181 |
+
raw_img = plot_confusion_matrix_from_csv(raw_csv_file,
|
| 182 |
+
f"Raw Confusion Matrix ({percentage:.1f}% data)",
|
| 183 |
+
raw_cm_img_path)
|
| 184 |
+
|
| 185 |
+
# Process embeddings confusion matrix
|
| 186 |
+
embeddings_csv_file = os.path.join(embeddings_folder, "confusion_matrix.csv")
|
| 187 |
+
embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 188 |
+
embeddings_img = plot_confusion_matrix_from_csv(embeddings_csv_file,
|
| 189 |
+
f"Embeddings Confusion Matrix ({percentage:.1f}% data)",
|
| 190 |
+
embeddings_cm_img_path)
|
| 191 |
+
|
| 192 |
+
return raw_img, embeddings_img
|
| 193 |
+
|
| 194 |
+
# Main function to handle user choice
|
| 195 |
+
def handle_user_choice(choice, percentage_idx=None, uploaded_file=None):
|
| 196 |
+
if choice == "Use Predefined Data":
|
| 197 |
+
return display_confusion_matrices_los(percentage_idx)
|
| 198 |
+
elif choice == "Upload Dataset":
|
| 199 |
+
if uploaded_file is not None:
|
| 200 |
+
return process_hdf5_file(uploaded_file, percentage_idx)
|
| 201 |
+
else:
|
| 202 |
+
return "Please upload a dataset", "Please upload a dataset"
|
| 203 |
+
else:
|
| 204 |
+
return "Invalid choice", "Invalid choice"
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
|
| 216 |
|
| 217 |
# Custom class to capture print output
|
| 218 |
class PrintCapture(io.StringIO):
|
|
|
|
| 457 |
os.chdir(original_dir)
|
| 458 |
sys.stdout = sys.__stdout__ # Reset print statements
|
| 459 |
|
| 460 |
+
######################## Define the Gradio interface ###############################
|
| 461 |
with gr.Blocks(css="""
|
| 462 |
.slider-container {
|
| 463 |
display: inline-block;
|
|
|
|
| 486 |
# Separate Tab for LoS/NLoS Classification Task
|
| 487 |
with gr.Tab("LoS/NLoS Classification Task"):
|
| 488 |
gr.Markdown("### LoS/NLoS Classification Task")
|
|
|
|
| 489 |
|
| 490 |
+
# Radio button for user choice: predefined data or upload dataset
|
| 491 |
+
choice_radio = gr.Radio(choices=["Use Predefined Data", "Upload Dataset"], label="Choose how to proceed", value="Use Predefined Data")
|
| 492 |
+
|
| 493 |
+
# Dropdown for selecting percentage for predefined data
|
| 494 |
+
percentage_dropdown_los = gr.Dropdown(choices=list(range(20)), value=0, label="Percentage of Data for Training")
|
| 495 |
+
|
| 496 |
+
# File uploader for dataset (only visible if user chooses to upload a dataset)
|
| 497 |
+
file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"], visible=False)
|
| 498 |
+
|
| 499 |
+
# Confusion matrices display
|
| 500 |
with gr.Row():
|
| 501 |
raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
|
| 502 |
embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
|
| 503 |
output_textbox = gr.Textbox(label="Console Output", lines=10)
|
| 504 |
|
| 505 |
+
# Update the file uploader visibility based on user choice
|
| 506 |
+
def toggle_file_input(choice):
|
| 507 |
+
return gr.update(visible=(choice == "Upload Dataset"))
|
| 508 |
+
|
| 509 |
+
choice_radio.change(fn=toggle_file_input, inputs=[choice_radio], outputs=file_input)
|
| 510 |
+
|
| 511 |
+
# When user makes a choice, update the display
|
| 512 |
+
choice_radio.change(fn=handle_user_choice, inputs=[choice_radio, percentage_dropdown_los, file_input],
|
| 513 |
+
outputs=[raw_img_los, embeddings_img_los, output_textbox])
|
| 514 |
+
|
| 515 |
+
# When percentage slider changes (for predefined data)
|
| 516 |
+
percentage_dropdown_los.change(fn=handle_user_choice, inputs=[choice_radio, percentage_dropdown_los, file_input],
|
| 517 |
+
outputs=[raw_img_los, embeddings_img_los, output_textbox])
|
| 518 |
|
| 519 |
# Launch the app
|
| 520 |
if __name__ == "__main__":
|