# Hosting-page residue (not part of the program): wi-lab's picture / Update app.py / 7729549 verified
import os
import sys
import json
from typing import List, Tuple, Optional, Dict
# Use a non-GUI backend for Matplotlib
import matplotlib
matplotlib.use("Agg")
import numpy as np
import torch
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from io_demo import (
list_demo_tasks, list_demo_dataset_files, load_pt_dataset,
)
from distances_common import (
pairwise_centroid_distances,
pairwise_cosine_similarity_distances,
sliced_wasserstein_distance_matrix,
)
from embed_raw import build_raw_embeddings
from embed_umap import build_umap_embeddings
from embed_lwm import get_lwm_encoder, build_lwm_embeddings
# ------------------------
# Small helpers / logging
# ------------------------
def _log(msg: str):
    """Write ``msg`` followed by a newline to stdout and flush immediately."""
    print(msg, end="\n", flush=True)
def _matrix_payload(np_mat: np.ndarray, labels: Optional[List[str]] = None):
    """
    Build the ``gr.update`` payload used to fill the distance Dataframe.

    Every value is rounded to three decimals and then rendered as a string
    with exactly three decimal places. ``labels`` are applied to the rows
    and/or columns only when their count matches that axis. The headers and
    the fixed row/column counts are derived from the resulting frame so they
    always agree with the value being sent.
    """
    table = pd.DataFrame(np_mat).round(3)

    # Turn each cell into fixed-precision text, one column at a time.
    three_decimals = "{:.3f}".format
    for column in table.columns:
        table[column] = table[column].apply(three_decimals)

    n_rows, n_cols = table.shape
    if labels is not None:
        if len(labels) == n_rows:
            table.index = labels
        if len(labels) == n_cols:
            table.columns = labels

    return gr.update(
        value=table,
        headers=list(table.columns),
        col_count=(n_cols, "fixed"),
        row_count=(n_rows, "fixed"),
    )
def _plot_heatmap(D: np.ndarray, labels: Optional[List[str]] = None) -> np.ndarray:
    """Render a matrix as a dark-themed heatmap and return it as an RGB image.

    Args:
        D: Matrix to display; handed to ``imshow`` unchanged.
        labels: Optional tick labels. They are applied to both axes only when
            their count equals ``D.shape[0]``; otherwise the default ticks
            are kept.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        canvas with the alpha channel dropped.
    """
    # Scope the dark style to this figure only. ``plt.style.use`` would change
    # the process-wide rcParams on every call as a side effect.
    with plt.style.context('dark_background'):
        fig, ax = plt.subplots(figsize=(6, 5), dpi=200, facecolor='#1e1e1e')
        try:
            ax.set_facecolor('#1e1e1e')
            im = ax.imshow(D, cmap="magma")
            cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            cbar.ax.tick_params(colors='white')  # Light colorbar labels

            if labels and len(labels) == D.shape[0]:
                tick_positions = np.arange(len(labels))
                ax.set_xticks(tick_positions)
                ax.set_yticks(tick_positions)
                ax.set_xticklabels(labels, rotation=60, ha="right", fontsize=8, color='white')
                ax.set_yticklabels(labels, fontsize=8, color='white')

            ax.set_title("Dataset Distance Matrix", color='white')
            ax.grid(False)

            # Light axis frame and ticks so they stay visible on the dark background.
            for side in ('bottom', 'top', 'right', 'left'):
                ax.spines[side].set_color('white')
            ax.tick_params(colors='white', which='both')
            fig.tight_layout()

            # Render and grab the RGBA buffer (works across Matplotlib versions).
            fig.canvas.draw()
            width, height = fig.canvas.get_width_height()
            buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
            img_rgba = buf.reshape((height, width, 4))
            # Copy so the returned pixels do not alias the canvas buffer that
            # is released when the figure is closed below.
            return img_rgba[..., :3].copy()
        finally:
            # Release the figure even when plotting or drawing raised.
            plt.close(fig)
def _load_uploaded_files(file_objs: List) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
    """Read uploaded Torch files into ``(channels, labels, file name)`` triples.

    Each file is expected to deserialize to a mapping with a 'channels' entry
    and an optional 'labels' entry (``None`` when absent). A file that cannot
    be read is reported through ``_log`` and skipped; a missing or empty
    upload list yields an empty result.
    """
    loaded = []
    if not file_objs:
        return loaded

    for entry in file_objs:
        # Use the object's ``name`` attribute when it has one; otherwise the
        # entry itself is treated as the path.
        location = getattr(entry, "name", entry)
        try:
            # NOTE(review): torch.load unpickles the file, and these files come
            # from users — confirm this is acceptable or restrict loading.
            payload = torch.load(location, map_location="cpu")
            channels = payload["channels"]
            targets = payload.get("labels", None)
            loaded.append((channels, targets, os.path.basename(location)))
        except Exception as err:
            _log(f"[WARN] Failed to load {location}: {err}")
    return loaded
def _load_demo_files(paths: List[str]) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
    """Load demo datasets via ``load_pt_dataset``.

    Every path that loads successfully contributes a
    ``(channels, labels, name)`` triple, where ``name`` is the name of the
    directory that contains the file (the scenario folder). Paths that fail
    are logged and skipped; ``None`` or an empty list yields an empty result.
    """
    results = []
    for dataset_path in (paths if paths else []):
        try:
            channels, targets = load_pt_dataset(dataset_path)
            # Label the dataset with its parent (scenario) folder name.
            scenario = os.path.basename(os.path.dirname(dataset_path))
            results.append((channels, targets, scenario))
        except Exception as err:
            _log(f"[WARN] Failed to load demo dataset {dataset_path}: {err}")
    return results
def _compute_embeddings(
    framework: str,
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
    umap_cfg: Dict
):
    """
    Dispatch to the embedding builder selected by ``framework``.

    "RAW" and "UMAP" call their builders directly. "LWM" uses the LWM encoder
    when ``get_lwm_encoder`` returns one and otherwise logs a warning and
    falls back to the RAW builder. ``umap_cfg`` is only read for "UMAP"
    (keys: "mode", "kwargs", "repr", "angle_delay_bins", each with a default).

    Returns:
        embs: torch.Tensor [D, n, d]
        labels_per_ds: Optional[List[torch.Tensor]]

    Raises:
        ValueError: if ``framework`` is not one of "RAW", "UMAP", "LWM".
    """
    if framework == "UMAP":
        vectors, per_ds_labels = build_umap_embeddings(
            datasets=datasets,
            n_per_dataset=n_per_dataset,
            label_aware=label_aware,
            umap_mode=umap_cfg.get("mode", "supervised"),
            umap_kwargs=umap_cfg.get("kwargs", {}),
            channel_representation=umap_cfg.get("repr", "raw"),
            angle_delay_bins=int(umap_cfg.get("angle_delay_bins", 16)),
        )
        return vectors, per_ds_labels

    if framework == "LWM":
        encoder = get_lwm_encoder()
        if encoder is not None:
            vectors, per_ds_labels = build_lwm_embeddings(
                model=encoder,
                datasets=datasets,
                n_per_dataset=n_per_dataset,
                label_aware=label_aware
            )
            return vectors, per_ds_labels
        # No encoder available: degrade to the RAW pipeline below.
        _log("[WARN] LWM encoder not available; falling back to RAW embeddings.")

    if framework in ("RAW", "LWM"):
        vectors, per_ds_labels = build_raw_embeddings(datasets, n_per_dataset, label_aware)
        return vectors, per_ds_labels

    raise ValueError(f"Unknown framework: {framework}")
def _compute_distance_matrix(
    embs: torch.Tensor,
    distance_mode: str,
    num_projections: int,
    label_aware: bool,
    labels_per_ds: Optional[List[torch.Tensor]],
    label_weighting: str,
    label_max_per_class: int
) -> torch.Tensor:
    """
    Compute the pairwise dataset distance matrix from embeddings.

    embs: [D, n, d]

    The two centroid-based modes ("euclidean_centroid", "cosine_similarity")
    are used only when ``label_aware`` is False; they operate on the mean of
    ``embs`` over dim 1. Every other combination — including any label-aware
    request and any other ``distance_mode`` value — goes through the sliced
    Wasserstein distance, which receives the label-related arguments.
    """
    if not label_aware:
        if distance_mode == "euclidean_centroid":
            centroids = embs.mean(dim=1)  # [D, d]
            return pairwise_centroid_distances(centroids)
        if distance_mode == "cosine_similarity":
            centroids = embs.mean(dim=1)  # [D, d]
            return pairwise_cosine_similarity_distances(centroids)

    # Sliced Wasserstein (supports label-aware)
    sw_options = dict(
        num_projections=num_projections,
        labels_per_ds=labels_per_ds,
        label_aware=label_aware,
        label_weighting=label_weighting,
        label_max_per_class=label_max_per_class,
    )
    return sliced_wasserstein_distance_matrix(embs, **sw_options)
# ------------------------
# Gradio callbacks
# ------------------------
def refresh_demo_tasks():
    """Return a ``gr.update`` listing the demo tasks, selecting the first one (or None if there are none)."""
    available = list_demo_tasks()
    selected = available[0] if available else None
    return gr.update(choices=available, value=selected)
def refresh_demo_scenarios(task: str):
    """Return a ``gr.update`` with the dataset files of ``task``.

    At most the first three files are pre-selected. With no task chosen,
    both the choices and the selection are cleared.
    """
    if task:
        scenario_files = list_demo_dataset_files(task)
        # Pre-select everything when there are fewer than three files.
        preselected = scenario_files if len(scenario_files) < 3 else scenario_files[:3]
        return gr.update(choices=scenario_files, value=preselected)
    return gr.update(choices=[], value=[])
def run_compute(
    framework: str,
    distance_mode: str,
    label_aware: bool,
    label_weighting: str,
    label_max_per_class: int,
    num_projections: int,
    n_eval_per_dataset: int,
    demo_task: str,
    demo_files: List[str],
    uploaded_files: List,
    umap_mode: str,
    umap_n_components: int,
    umap_n_neighbors: int,
    umap_min_dist: float,
    umap_metric: str,
    umap_spread: float,
    umap_learning_rate: float,
    umap_n_epochs: int,
    umap_negative_sample_rate: int,
    umap_init: str,
    umap_densmap: bool,
    umap_set_op_mix_ratio: float,
    umap_local_connectivity: float,
    umap_repulsion_strength: float,
    umap_random_state: int,
    channel_representation: str,
    angle_delay_bins: int,
):
    """Callback for the compute button: load datasets, embed, measure, render.

    Demo files are loaded only when both ``demo_task`` and ``demo_files`` are
    set; uploaded files are always appended after them. The ``umap_*``
    arguments are coerced to plain Python types and forwarded, together with
    ``channel_representation`` and ``angle_delay_bins``, as the UMAP
    configuration. The resulting distance matrix is min-max normalized to
    [0, 1] before it is shown.

    Returns:
        A 3-tuple matching the callback outputs: a status update, the
        Dataframe payload for the matrix, and the heatmap image. When fewer
        than two datasets could be loaded, the status asks for more data, the
        table is empty and the image is ``None``.
    """
    datasets = []
    if demo_task and demo_files:
        datasets.extend(_load_demo_files(demo_files))
    datasets.extend(_load_uploaded_files(uploaded_files))

    # A pairwise distance needs at least two datasets.
    if len(datasets) < 2:
        return (
            gr.update(value="Please provide at least 2 datasets (demo or upload)."),
            _matrix_payload(np.zeros((0, 0))),
            None
        )

    names = [name for _, _, name in datasets]

    # UI widgets may deliver floats/None; coerce to the types UMAP expects.
    umap_kwargs = dict(
        n_components=int(umap_n_components),
        n_neighbors=int(umap_n_neighbors),
        min_dist=float(umap_min_dist),
        metric=umap_metric,
        spread=float(umap_spread),
        learning_rate=float(umap_learning_rate),
        # 0 (or an empty field) means "let UMAP choose the epoch count".
        n_epochs=None if umap_n_epochs in [None, 0] else int(umap_n_epochs),
        negative_sample_rate=int(umap_negative_sample_rate),
        init=umap_init,
        densmap=bool(umap_densmap),
        set_op_mix_ratio=float(umap_set_op_mix_ratio),
        local_connectivity=float(umap_local_connectivity),
        repulsion_strength=float(umap_repulsion_strength),
        random_state=int(umap_random_state),
        # Fixed here; not exposed in the UI.
        target_metric="categorical",
        target_weight=0.5,
    )

    embs, labels_per_ds = _compute_embeddings(
        framework=framework,
        datasets=datasets,
        n_per_dataset=int(n_eval_per_dataset),
        label_aware=bool(label_aware),
        umap_cfg={
            "mode": umap_mode,
            "kwargs": umap_kwargs,
            "repr": channel_representation,
            "angle_delay_bins": int(angle_delay_bins),
        }
    )

    D = _compute_distance_matrix(
        embs=embs,
        distance_mode=distance_mode,
        num_projections=int(num_projections),
        label_aware=bool(label_aware),
        labels_per_ds=labels_per_ds,
        label_weighting=label_weighting,
        label_max_per_class=int(label_max_per_class),
    )
    D_np = D.detach().cpu().numpy()

    # Normalize distance matrix to [0, 1] range (min-max normalization)
    d_min = D_np.min()
    d_max = D_np.max()
    if d_max > d_min:  # Avoid division by zero
        D_np = (D_np - d_min) / (d_max - d_min)
    else:
        D_np = np.zeros_like(D_np)  # All values are the same, set to 0

    img = _plot_heatmap(D_np, labels=names)
    return (
        # The check mark was stored as mis-decoded UTF-8 bytes; restored here.
        gr.update(value="Done ✅"),
        _matrix_payload(D_np, labels=names),
        img
    )
# ------------------------
# UI
# ------------------------
# Page layout: intro text and citation on top, then a row of three columns
# (settings, data selection + run button, results). Mis-encoded emoji and the
# accented author name in the user-facing strings are restored to proper UTF-8.
with gr.Blocks(title="Dataset Distancing Lab") as demo:
    gr.Markdown(
        """
# Dataset Distancing Lab
Compute distances between datasets using **RAW**, **UMAP**, or **LWM** embeddings.
Upload your `.pt`/`.p` datasets or try the built-in samples under `data/{task}/{scenario}/...`.
**Format:** each file should be a Torch file with keys:
- `channels`: `Tensor[N, ...]` (complex supported; real+imag will be concatenated)
- `labels` (optional): `Tensor[N]`
"""
    )
    with gr.Accordion("📚 Citation", open=False):
        gr.Markdown(
            """
If you use this lab or methods in your work, please cite:
```bibtex
@INPROCEEDINGS{10942657,
author={Morais, João and Alikhani, Sadjad and Malhotra, Akshay and Hamidi-Rad, Shahab and Alkhateeb, Ahmed},
booktitle={2024 58th Asilomar Conference on Signals, Systems, and Computers},
title={A Dataset Similarity Evaluation Framework for Wireless Communications and Sensing},
year={2024},
volume={},
number={},
pages={1144-1149},
keywords={Wireless communication;Dimensionality reduction;Adaptation models;Wireless sensor networks;Nearest neighbor methods;Extraterrestrial measurements;Data structures;Distance measurement;Data models;Sensors},
doi={10.1109/IEEECONF60004.2024.10942657}}
```
"""
        )
    with gr.Row():
        # --- Settings: embedding framework, distance mode, UMAP options ---
        with gr.Column(scale=1, min_width=320):
            gr.Markdown("### Framework & Distance")
            framework = gr.Radio(
                choices=["RAW", "UMAP", "LWM"],
                value="RAW",
                label="Framework",
            )
            distance_mode = gr.Radio(
                choices=["sliced_wasserstein", "euclidean_centroid", "cosine_similarity"],
                value="sliced_wasserstein",
                label="Distance Mode"
            )
            label_aware = gr.Checkbox(value=True, label="Label-aware (supported by SW distance)")
            label_weighting = gr.Dropdown(
                choices=["uniform", "support"],
                value="uniform",
                label="Label weighting"
            )
            label_max_per_class = gr.Number(value=1e10, precision=0, label="Max samples / class")
            num_projections = gr.Slider(8, 256, value=64, step=1, label="SW #projections")
            n_eval_per_dataset = gr.Slider(32, 4096, value=1024, step=32, label="Samples per dataset")

            gr.Markdown("### UMAP (only if Framework=UMAP)")
            umap_mode = gr.Dropdown(["unsupervised", "supervised"], value="supervised", label="UMAP Mode")
            channel_representation = gr.Dropdown(["raw", "angle_delay"], value="raw", label="Channel representation")
            angle_delay_bins = gr.Slider(4, 128, value=16, step=1, label="Angle-delay bins (if used)")
            with gr.Accordion("Advanced UMAP settings", open=False):
                umap_n_components = gr.Slider(2, 256, value=128, step=1, label="n_components")
                umap_n_neighbors = gr.Slider(2, 128, value=32, step=1, label="n_neighbors")
                umap_min_dist = gr.Slider(0.0, 0.99, value=0.1, step=0.01, label="min_dist")
                umap_metric = gr.Dropdown(
                    ["euclidean", "cosine", "manhattan", "chebyshev", "correlation"],
                    value="euclidean", label="metric"
                )
                umap_spread = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="spread")
                umap_learning_rate = gr.Slider(0.1, 10.0, value=1.0, step=0.1, label="learning_rate")
                umap_n_epochs = gr.Number(value=0, precision=0, label="n_epochs (0 = auto)")
                umap_negative_sample_rate = gr.Slider(1, 50, value=5, step=1, label="negative_sample_rate")
                umap_init = gr.Dropdown(["spectral", "random"], value="spectral", label="init")
                umap_densmap = gr.Checkbox(value=False, label="densMAP")
                umap_set_op_mix_ratio = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="set_op_mix_ratio")
                umap_local_connectivity = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="local_connectivity")
                umap_repulsion_strength = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="repulsion_strength")
                umap_random_state = gr.Number(value=42, precision=0, label="random_state")

        # --- Data selection: bundled demo datasets and/or uploads ---
        with gr.Column(scale=1, min_width=320):
            gr.Markdown("### Demo datasets (data/{task}/{scenario}/...)")
            # Both lists start empty; the refresh button fills the task list,
            # and changing the task fills the scenario list.
            demo_task = gr.Dropdown(choices=[], value=None, label="Task", interactive=True)
            demo_select = gr.CheckboxGroup(choices=[], value=[], label="Scenarios (files inside each scenario)")
            refresh = gr.Button("🔄 Refresh Demo Lists")
            refresh.click(
                fn=refresh_demo_tasks,
                inputs=[],
                outputs=[demo_task]
            )
            demo_task.change(
                fn=refresh_demo_scenarios,
                inputs=[demo_task],
                outputs=[demo_select]
            )

            gr.Markdown("### Or upload your own")
            uploads = gr.Files(
                label="Upload multiple .pt/.p datasets",
                file_count="multiple",
                file_types=[".pt", ".p"]
            )
            run_btn = gr.Button("🚀 Compute distances", variant="primary")
            status = gr.Markdown("")

        # --- Results: distance table and heatmap ---
        with gr.Column(scale=2):
            gr.Markdown("### Distance Matrix (Table)")
            matrix_out = gr.Dataframe(
                value=None,
                headers=None,
                interactive=False,
                wrap=True,
                row_count=(0, "dynamic"),
                col_count=(0, "dynamic"),
                label="Distances"
            )
            gr.Markdown("### Distance Matrix (Heatmap)")
            heatmap = gr.Image(type="numpy", interactive=False)

    # The input order must match the positional parameters of run_compute.
    run_btn.click(
        fn=run_compute,
        inputs=[
            framework, distance_mode, label_aware, label_weighting, label_max_per_class,
            num_projections, n_eval_per_dataset, demo_task, demo_select, uploads,
            umap_mode, umap_n_components, umap_n_neighbors, umap_min_dist, umap_metric,
            umap_spread, umap_learning_rate, umap_n_epochs, umap_negative_sample_rate,
            umap_init, umap_densmap, umap_set_op_mix_ratio, umap_local_connectivity,
            umap_repulsion_strength, umap_random_state, channel_representation, angle_delay_bins
        ],
        outputs=[status, matrix_out, heatmap]
    )


if __name__ == "__main__":
    demo.launch()