Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -26,7 +26,7 @@ def beam_prediction_task(data_percentage, task_complexity, theme='Dark'):
|
|
| 26 |
raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
| 27 |
plot_confusion_matrix_beamPred(raw_cm,
|
| 28 |
classes=np.arange(raw_cm.shape[0]),
|
| 29 |
-
title=f"
|
| 30 |
save_path=raw_cm_path,
|
| 31 |
theme=theme)
|
| 32 |
raw_img = Image.open(raw_cm_path)
|
|
@@ -39,7 +39,7 @@ def beam_prediction_task(data_percentage, task_complexity, theme='Dark'):
|
|
| 39 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 40 |
plot_confusion_matrix_beamPred(embeddings_cm,
|
| 41 |
classes=np.arange(embeddings_cm.shape[0]),
|
| 42 |
-
title=f"
|
| 43 |
save_path=embeddings_cm_path,
|
| 44 |
theme=theme)
|
| 45 |
embeddings_img = Image.open(embeddings_cm_path)
|
|
@@ -191,14 +191,14 @@ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path, light_mode=F
|
|
| 191 |
sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
| 192 |
|
| 193 |
# Add F1-score to the title
|
| 194 |
-
plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=
|
| 195 |
|
| 196 |
# Customize tick labels for light/dark mode
|
| 197 |
-
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=
|
| 198 |
-
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=
|
| 199 |
|
| 200 |
-
plt.ylabel('True label', color=text_color, fontsize=
|
| 201 |
-
plt.xlabel('Predicted label', color=text_color, fontsize=
|
| 202 |
plt.tight_layout()
|
| 203 |
|
| 204 |
# Save the plot as an image
|
|
@@ -220,14 +220,14 @@ def display_confusion_matrices_los(percentage):
|
|
| 220 |
raw_csv_file = os.path.join(raw_folder, f"test_predictions_raw_{percentage/100:.3f}_los.csv")
|
| 221 |
raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
| 222 |
raw_img = plot_confusion_matrix_from_csv(raw_csv_file,
|
| 223 |
-
f"
|
| 224 |
raw_cm_img_path)
|
| 225 |
|
| 226 |
# Process embeddings confusion matrix
|
| 227 |
embeddings_csv_file = os.path.join(embeddings_folder, f"test_predictions_embedding_{percentage/100:.3f}_los.csv")
|
| 228 |
embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 229 |
embeddings_img = plot_confusion_matrix_from_csv(embeddings_csv_file,
|
| 230 |
-
f"
|
| 231 |
embeddings_cm_img_path)
|
| 232 |
|
| 233 |
return raw_img, embeddings_img
|
|
@@ -362,14 +362,14 @@ def plot_confusion_matrix(y_true, y_pred, title, light_mode=False):
|
|
| 362 |
sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
| 363 |
|
| 364 |
# Add F1-score to the title
|
| 365 |
-
plt.title(f"{title}\
|
| 366 |
|
| 367 |
# Customize tick labels for dark mode
|
| 368 |
-
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=
|
| 369 |
-
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=
|
| 370 |
|
| 371 |
-
plt.ylabel('True label', color=text_color, fontsize=
|
| 372 |
-
plt.xlabel('Predicted label', color=text_color, fontsize=
|
| 373 |
plt.tight_layout()
|
| 374 |
|
| 375 |
# Save the plot as an image
|
|
|
|
| 26 |
raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
| 27 |
plot_confusion_matrix_beamPred(raw_cm,
|
| 28 |
classes=np.arange(raw_cm.shape[0]),
|
| 29 |
+
title=f"Confusion Matrix (Raw Channels)\n{data_percentage}% data, {task_complexity} beams",
|
| 30 |
save_path=raw_cm_path,
|
| 31 |
theme=theme)
|
| 32 |
raw_img = Image.open(raw_cm_path)
|
|
|
|
| 39 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 40 |
plot_confusion_matrix_beamPred(embeddings_cm,
|
| 41 |
classes=np.arange(embeddings_cm.shape[0]),
|
| 42 |
+
title=f"Confusion Matrix (LWM Embeddings)\n{data_percentage}% data, {task_complexity} beams",
|
| 43 |
save_path=embeddings_cm_path,
|
| 44 |
theme=theme)
|
| 45 |
embeddings_img = Image.open(embeddings_cm_path)
|
|
|
|
| 191 |
sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
| 192 |
|
| 193 |
# Add F1-score to the title
|
| 194 |
+
plt.title(f"{title}\n(F1 Score: {f1:.3f})", color=text_color, fontsize=24)
|
| 195 |
|
| 196 |
# Customize tick labels for light/dark mode
|
| 197 |
+
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
|
| 198 |
+
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
|
| 199 |
|
| 200 |
+
plt.ylabel('True label', color=text_color, fontsize=18)
|
| 201 |
+
plt.xlabel('Predicted label', color=text_color, fontsize=18)
|
| 202 |
plt.tight_layout()
|
| 203 |
|
| 204 |
# Save the plot as an image
|
|
|
|
| 220 |
raw_csv_file = os.path.join(raw_folder, f"test_predictions_raw_{percentage/100:.3f}_los.csv")
|
| 221 |
raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
| 222 |
raw_img = plot_confusion_matrix_from_csv(raw_csv_file,
|
| 223 |
+
f"Confusion Matrix (Raw Channels)\n{percentage:.1f}% data",
|
| 224 |
raw_cm_img_path)
|
| 225 |
|
| 226 |
# Process embeddings confusion matrix
|
| 227 |
embeddings_csv_file = os.path.join(embeddings_folder, f"test_predictions_embedding_{percentage/100:.3f}_los.csv")
|
| 228 |
embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 229 |
embeddings_img = plot_confusion_matrix_from_csv(embeddings_csv_file,
|
| 230 |
+
f"Confusion Matrix (LWM Embeddings)\n{percentage:.1f}% data",
|
| 231 |
embeddings_cm_img_path)
|
| 232 |
|
| 233 |
return raw_img, embeddings_img
|
|
|
|
| 362 |
sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, cbar=False, annot_kws={"size": 12}, linewidths=0.5, linecolor='white')
|
| 363 |
|
| 364 |
# Add F1-score to the title
|
| 365 |
+
plt.title(f"{title}\nF1 Score: {f1:.3f}", color=text_color, fontsize=23)
|
| 366 |
|
| 367 |
# Customize tick labels for dark mode
|
| 368 |
+
plt.xticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
|
| 369 |
+
plt.yticks([0.5, 1.5], labels=['Class 0', 'Class 1'], color=text_color, fontsize=14)
|
| 370 |
|
| 371 |
+
plt.ylabel('True label', color=text_color, fontsize=18)
|
| 372 |
+
plt.xlabel('Predicted label', color=text_color, fontsize=18)
|
| 373 |
plt.tight_layout()
|
| 374 |
|
| 375 |
# Save the plot as an image
|