Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -39,8 +39,11 @@ def _matrix_payload(np_mat: np.ndarray, labels: Optional[List[str]] = None):
|
|
| 39 |
Return a safe gr.update payload for a Dataframe (headers must match col_count).
|
| 40 |
"""
|
| 41 |
df = pd.DataFrame(np_mat)
|
| 42 |
-
#
|
| 43 |
df = df.round(3)
|
|
|
|
|
|
|
|
|
|
| 44 |
if labels is not None and len(labels) == df.shape[0]:
|
| 45 |
df.index = labels
|
| 46 |
if labels is not None and len(labels) == df.shape[1]:
|
|
@@ -58,7 +61,7 @@ def _plot_heatmap(D: np.ndarray, labels: Optional[List[str]] = None) -> np.ndarr
|
|
| 58 |
# Set dark mode style
|
| 59 |
plt.style.use('dark_background')
|
| 60 |
|
| 61 |
-
fig, ax = plt.subplots(figsize=(6, 5), dpi=
|
| 62 |
ax.set_facecolor('#1e1e1e')
|
| 63 |
|
| 64 |
im = ax.imshow(D, cmap="magma")
|
|
@@ -299,6 +302,14 @@ def run_compute(
|
|
| 299 |
label_max_per_class=int(label_max_per_class),
|
| 300 |
)
|
| 301 |
D_np = D.detach().cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
img = _plot_heatmap(D_np, labels=names)
|
| 304 |
return (
|
|
|
|
| 39 |
Return a safe gr.update payload for a Dataframe (headers must match col_count).
|
| 40 |
"""
|
| 41 |
df = pd.DataFrame(np_mat)
|
| 42 |
+
# Format values to always show 3 decimal places
|
| 43 |
df = df.round(3)
|
| 44 |
+
# Format each value to show exactly 3 decimal places
|
| 45 |
+
for col in df.columns:
|
| 46 |
+
df[col] = df[col].apply(lambda x: f"{x:.3f}")
|
| 47 |
if labels is not None and len(labels) == df.shape[0]:
|
| 48 |
df.index = labels
|
| 49 |
if labels is not None and len(labels) == df.shape[1]:
|
|
|
|
| 61 |
# Set dark mode style
|
| 62 |
plt.style.use('dark_background')
|
| 63 |
|
| 64 |
+
fig, ax = plt.subplots(figsize=(6, 5), dpi=200, facecolor='#1e1e1e')
|
| 65 |
ax.set_facecolor('#1e1e1e')
|
| 66 |
|
| 67 |
im = ax.imshow(D, cmap="magma")
|
|
|
|
| 302 |
label_max_per_class=int(label_max_per_class),
|
| 303 |
)
|
| 304 |
D_np = D.detach().cpu().numpy()
|
| 305 |
+
|
| 306 |
+
# Normalize distance matrix to [0, 1] range (min-max normalization)
|
| 307 |
+
d_min = D_np.min()
|
| 308 |
+
d_max = D_np.max()
|
| 309 |
+
if d_max > d_min: # Avoid division by zero
|
| 310 |
+
D_np = (D_np - d_min) / (d_max - d_min)
|
| 311 |
+
else:
|
| 312 |
+
D_np = np.zeros_like(D_np) # All values are the same, set to 0
|
| 313 |
|
| 314 |
img = _plot_heatmap(D_np, labels=names)
|
| 315 |
return (
|