Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,7 +15,7 @@ from sklearn.metrics import f1_score
|
|
| 15 |
import seaborn as sns
|
| 16 |
|
| 17 |
#################### BEAM PREDICTION #########################}
|
| 18 |
-
def beam_prediction_task(data_percentage, task_complexity):
|
| 19 |
# Folder naming convention based on input_type, data_percentage, and task_complexity
|
| 20 |
raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
|
| 21 |
embeddings_folder = f"images/embedding_{data_percentage/100:.1f}_{task_complexity}"
|
|
@@ -24,7 +24,7 @@ def beam_prediction_task(data_percentage, task_complexity):
|
|
| 24 |
raw_cm = compute_average_confusion_matrix(raw_folder)
|
| 25 |
if raw_cm is not None:
|
| 26 |
raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
| 27 |
-
plot_confusion_matrix_beamPred(raw_cm, classes=np.arange(raw_cm.shape[0]), title=f"Raw Confusion Matrix\n({data_percentage}% data, {task_complexity} beams)", save_path=raw_cm_path)
|
| 28 |
raw_img = Image.open(raw_cm_path)
|
| 29 |
else:
|
| 30 |
raw_img = None
|
|
@@ -33,15 +33,13 @@ def beam_prediction_task(data_percentage, task_complexity):
|
|
| 33 |
embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
|
| 34 |
if embeddings_cm is not None:
|
| 35 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 36 |
-
plot_confusion_matrix_beamPred(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix\n({data_percentage}% data, {task_complexity} beams)", save_path=embeddings_cm_path)
|
| 37 |
embeddings_img = Image.open(embeddings_cm_path)
|
| 38 |
else:
|
| 39 |
embeddings_img = None
|
| 40 |
|
| 41 |
return raw_img, embeddings_img
|
| 42 |
|
| 43 |
-
from sklearn.metrics import f1_score
|
| 44 |
-
|
| 45 |
# Function to compute the F1-score based on the confusion matrix
|
| 46 |
def compute_f1_score(cm):
|
| 47 |
# Compute precision and recall
|
|
@@ -61,30 +59,12 @@ def compute_f1_score(cm):
|
|
| 61 |
f1 = np.nan_to_num(f1) # Replace NaN with 0
|
| 62 |
return np.mean(f1) # Return the mean F1-score across all classes
|
| 63 |
|
| 64 |
-
|
| 65 |
-
import seaborn as sns
|
| 66 |
-
import numpy as np
|
| 67 |
-
from PIL import Image
|
| 68 |
-
|
| 69 |
-
def plot_confusion_matrix_beamPred(cm, classes, title, save_path, dark_mode=None):
|
| 70 |
-
"""
|
| 71 |
-
Plot confusion matrix and adjust colors based on light/dark mode settings.
|
| 72 |
-
:param cm: Confusion matrix data.
|
| 73 |
-
:param classes: List of class labels.
|
| 74 |
-
:param title: Plot title.
|
| 75 |
-
:param save_path: Path to save the plot.
|
| 76 |
-
:param dark_mode: Boolean to toggle between light and dark modes. If None, use the current theme.
|
| 77 |
-
"""
|
| 78 |
-
|
| 79 |
-
# If dark_mode is None, try detecting it from rcParams (matplotlib theme)
|
| 80 |
-
if dark_mode is None:
|
| 81 |
-
dark_mode = plt.rcParams['axes.facecolor'] == '#333333' # Check if dark background is set
|
| 82 |
-
|
| 83 |
# Compute the average F1-score
|
| 84 |
avg_f1 = compute_f1_score(cm)
|
| 85 |
|
| 86 |
-
# Choose the color scheme based on the mode
|
| 87 |
-
if
|
| 88 |
plt.style.use('dark_background') # Use dark mode styling
|
| 89 |
text_color = 'white'
|
| 90 |
cmap = 'cividis' # Dark-mode-friendly colormap
|
|
@@ -95,7 +75,7 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path, dark_mode=None
|
|
| 95 |
|
| 96 |
plt.figure(figsize=(10, 10))
|
| 97 |
|
| 98 |
-
# Plot the confusion matrix with the
|
| 99 |
sns.heatmap(cm, cmap=cmap, cbar=True, linecolor='white', vmin=0, vmax=cm.max(), alpha=0.85)
|
| 100 |
|
| 101 |
# Add F1-score to the title
|
|
@@ -579,12 +559,21 @@ with gr.Blocks(css="""
|
|
| 579 |
task_complexity_dropdown = gr.Dropdown(label="Task Complexity (Number of Beams)", choices=[16, 32, 64, 128, 256], value=16)
|
| 580 |
|
| 581 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=500)
|
| 583 |
embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=500)
|
| 584 |
|
| 585 |
# Update the confusion matrices whenever sliders change
|
| 586 |
-
data_percentage_slider.change(fn=beam_prediction_task, inputs=[data_percentage_slider, task_complexity_dropdown], outputs=[raw_img_bp, embeddings_img_bp])
|
| 587 |
-
task_complexity_dropdown.change(fn=beam_prediction_task, inputs=[data_percentage_slider, task_complexity_dropdown], outputs=[raw_img_bp, embeddings_img_bp])
|
| 588 |
|
| 589 |
# Add a conclusion section at the bottom
|
| 590 |
gr.Markdown("""
|
|
|
|
| 15 |
import seaborn as sns
|
| 16 |
|
| 17 |
#################### BEAM PREDICTION #########################}
|
| 18 |
+
def beam_prediction_task(data_percentage, task_complexity, user_mode="light"):
|
| 19 |
# Folder naming convention based on input_type, data_percentage, and task_complexity
|
| 20 |
raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
|
| 21 |
embeddings_folder = f"images/embedding_{data_percentage/100:.1f}_{task_complexity}"
|
|
|
|
| 24 |
raw_cm = compute_average_confusion_matrix(raw_folder)
|
| 25 |
if raw_cm is not None:
|
| 26 |
raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
| 27 |
+
plot_confusion_matrix_beamPred(raw_cm, classes=np.arange(raw_cm.shape[0]), title=f"Raw Confusion Matrix\n({data_percentage}% data, {task_complexity} beams)", save_path=raw_cm_path, user_mode=user_mode)
|
| 28 |
raw_img = Image.open(raw_cm_path)
|
| 29 |
else:
|
| 30 |
raw_img = None
|
|
|
|
| 33 |
embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
|
| 34 |
if embeddings_cm is not None:
|
| 35 |
embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
| 36 |
+
plot_confusion_matrix_beamPred(embeddings_cm, classes=np.arange(embeddings_cm.shape[0]), title=f"Embeddings Confusion Matrix\n({data_percentage}% data, {task_complexity} beams)", save_path=embeddings_cm_path, user_mode=user_mode)
|
| 37 |
embeddings_img = Image.open(embeddings_cm_path)
|
| 38 |
else:
|
| 39 |
embeddings_img = None
|
| 40 |
|
| 41 |
return raw_img, embeddings_img
|
| 42 |
|
|
|
|
|
|
|
| 43 |
# Function to compute the F1-score based on the confusion matrix
|
| 44 |
def compute_f1_score(cm):
|
| 45 |
# Compute precision and recall
|
|
|
|
| 59 |
f1 = np.nan_to_num(f1) # Replace NaN with 0
|
| 60 |
return np.mean(f1) # Return the mean F1-score across all classes
|
| 61 |
|
| 62 |
+
def plot_confusion_matrix_beamPred(cm, classes, title, save_path, user_mode="light"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
# Compute the average F1-score
|
| 64 |
avg_f1 = compute_f1_score(cm)
|
| 65 |
|
| 66 |
+
# Choose the color scheme based on the user's mode
|
| 67 |
+
if user_mode == 'dark':
|
| 68 |
plt.style.use('dark_background') # Use dark mode styling
|
| 69 |
text_color = 'white'
|
| 70 |
cmap = 'cividis' # Dark-mode-friendly colormap
|
|
|
|
| 75 |
|
| 76 |
plt.figure(figsize=(10, 10))
|
| 77 |
|
| 78 |
+
# Plot the confusion matrix with a colormap compatible for the mode
|
| 79 |
sns.heatmap(cm, cmap=cmap, cbar=True, linecolor='white', vmin=0, vmax=cm.max(), alpha=0.85)
|
| 80 |
|
| 81 |
# Add F1-score to the title
|
|
|
|
| 559 |
task_complexity_dropdown = gr.Dropdown(label="Task Complexity (Number of Beams)", choices=[16, 32, 64, 128, 256], value=16)
|
| 560 |
|
| 561 |
with gr.Row():
|
| 562 |
+
|
| 563 |
+
mode_input = gr.Textbox(visible=False) # Hidden input to capture user mode
|
| 564 |
+
gr.Markdown("""
|
| 565 |
+
<script>
|
| 566 |
+
const userPrefersDark = window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches;
|
| 567 |
+
document.querySelector('input[name="mode_input"]').value = userPrefersDark ? 'dark' : 'light';
|
| 568 |
+
</script>
|
| 569 |
+
""")
|
| 570 |
+
|
| 571 |
raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=500)
|
| 572 |
embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=500)
|
| 573 |
|
| 574 |
# Update the confusion matrices whenever sliders change
|
| 575 |
+
data_percentage_slider.change(fn=beam_prediction_task, inputs=[data_percentage_slider, task_complexity_dropdown, mode_input], outputs=[raw_img_bp, embeddings_img_bp])
|
| 576 |
+
task_complexity_dropdown.change(fn=beam_prediction_task, inputs=[data_percentage_slider, task_complexity_dropdown, mode_input], outputs=[raw_img_bp, embeddings_img_bp])
|
| 577 |
|
| 578 |
# Add a conclusion section at the bottom
|
| 579 |
gr.Markdown("""
|