Spaces:
Running
Running
Sadjad Alikhani
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -31,7 +31,7 @@ def beam_prediction_task(data_percentage, task_complexity):
|
|
| 31 |
raw_cm = compute_average_confusion_matrix(raw_folder)
|
| 32 |
if raw_cm is not None:
|
| 33 |
raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
| 34 |
-
|
| 35 |
raw_img = Image.open(raw_cm_path)
|
| 36 |
else:
|
| 37 |
raw_img = None
|
|
@@ -40,14 +40,28 @@ def beam_prediction_task(data_percentage, task_complexity):
|
|
| 40 |
embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
|
| 41 |
if embeddings_cm is not None:
|
| 42 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 43 |
-
|
| 44 |
embeddings_img = Image.open(embeddings_cm_path)
|
| 45 |
else:
|
| 46 |
embeddings_img = None
|
| 47 |
|
| 48 |
return raw_img, embeddings_img
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
# Function to compute the average confusion matrix across CSV files in a folder
|
| 52 |
def compute_average_confusion_matrix(folder):
|
| 53 |
confusion_matrices = []
|
|
|
|
| 31 |
raw_cm = compute_average_confusion_matrix(raw_folder)
|
| 32 |
if raw_cm is not None:
|
| 33 |
raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
| 34 |
+
plot_confusion_matrix_beamPred(raw_cm, classes=np.arange(raw_cm.shape[0]), title=f"Raw Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=raw_cm_path)
|
| 35 |
raw_img = Image.open(raw_cm_path)
|
| 36 |
else:
|
| 37 |
raw_img = None
|
|
|
|
| 40 |
embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
|
| 41 |
if embeddings_cm is not None:
|
| 42 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 43 |
+
plot_confusion_matrix_beamPred(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix ({data_percentage}% data, {task_complexity} beams)", save_path=embeddings_cm_path)
|
| 44 |
embeddings_img = Image.open(embeddings_cm_path)
|
| 45 |
else:
|
| 46 |
embeddings_img = None
|
| 47 |
|
| 48 |
return raw_img, embeddings_img
|
| 49 |
|
| 50 |
+
def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
| 51 |
+
plt.figure(figsize=(8, 6))
|
| 52 |
+
plt.imshow(cm, interpolation='nearest', cmap='coolwarm')
|
| 53 |
+
plt.title(title)
|
| 54 |
+
plt.colorbar()
|
| 55 |
+
tick_marks = np.arange(len(classes))
|
| 56 |
+
plt.xticks(tick_marks, classes, rotation=45)
|
| 57 |
+
plt.yticks(tick_marks, classes)
|
| 58 |
|
| 59 |
+
plt.tight_layout()
|
| 60 |
+
plt.ylabel('True label')
|
| 61 |
+
plt.xlabel('Predicted label')
|
| 62 |
+
plt.savefig(save_path)
|
| 63 |
+
plt.close()
|
| 64 |
+
|
| 65 |
# Function to compute the average confusion matrix across CSV files in a folder
|
| 66 |
def compute_average_confusion_matrix(folder):
|
| 67 |
confusion_matrices = []
|