wi-lab commited on
Commit
4410a6f
·
verified ·
1 Parent(s): d854d43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -39,8 +39,11 @@ def _matrix_payload(np_mat: np.ndarray, labels: Optional[List[str]] = None):
39
  Return a safe gr.update payload for a Dataframe (headers must match col_count).
40
  """
41
  df = pd.DataFrame(np_mat)
42
- # Round to 3 decimal places
43
  df = df.round(3)
 
 
 
44
  if labels is not None and len(labels) == df.shape[0]:
45
  df.index = labels
46
  if labels is not None and len(labels) == df.shape[1]:
@@ -58,7 +61,7 @@ def _plot_heatmap(D: np.ndarray, labels: Optional[List[str]] = None) -> np.ndarr
58
  # Set dark mode style
59
  plt.style.use('dark_background')
60
 
61
- fig, ax = plt.subplots(figsize=(6, 5), dpi=160, facecolor='#1e1e1e')
62
  ax.set_facecolor('#1e1e1e')
63
 
64
  im = ax.imshow(D, cmap="magma")
@@ -299,6 +302,14 @@ def run_compute(
299
  label_max_per_class=int(label_max_per_class),
300
  )
301
  D_np = D.detach().cpu().numpy()
 
 
 
 
 
 
 
 
302
 
303
  img = _plot_heatmap(D_np, labels=names)
304
  return (
 
39
  Return a safe gr.update payload for a Dataframe (headers must match col_count).
40
  """
41
  df = pd.DataFrame(np_mat)
42
+ # Format values to always show 3 decimal places
43
  df = df.round(3)
44
+ # Format each value to show exactly 3 decimal places
45
+ for col in df.columns:
46
+ df[col] = df[col].apply(lambda x: f"{x:.3f}")
47
  if labels is not None and len(labels) == df.shape[0]:
48
  df.index = labels
49
  if labels is not None and len(labels) == df.shape[1]:
 
61
  # Set dark mode style
62
  plt.style.use('dark_background')
63
 
64
+ fig, ax = plt.subplots(figsize=(6, 5), dpi=200, facecolor='#1e1e1e')
65
  ax.set_facecolor('#1e1e1e')
66
 
67
  im = ax.imshow(D, cmap="magma")
 
302
  label_max_per_class=int(label_max_per_class),
303
  )
304
  D_np = D.detach().cpu().numpy()
305
+
306
+ # Normalize distance matrix to [0, 1] range (min-max normalization)
307
+ d_min = D_np.min()
308
+ d_max = D_np.max()
309
+ if d_max > d_min: # Avoid division by zero
310
+ D_np = (D_np - d_min) / (d_max - d_min)
311
+ else:
312
+ D_np = np.zeros_like(D_np) # All values are the same, set to 0
313
 
314
  img = _plot_heatmap(D_np, labels=names)
315
  return (