Spaces:
Running
Running
| import os | |
| import sys | |
| import json | |
| from typing import List, Tuple, Optional, Dict | |
| # Use a non-GUI backend for Matplotlib | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from io_demo import ( | |
| list_demo_tasks, list_demo_dataset_files, load_pt_dataset, | |
| ) | |
| from distances_common import ( | |
| pairwise_centroid_distances, | |
| pairwise_cosine_similarity_distances, | |
| sliced_wasserstein_distance_matrix, | |
| ) | |
| from embed_raw import build_raw_embeddings | |
| from embed_umap import build_umap_embeddings | |
| from embed_lwm import get_lwm_encoder, build_lwm_embeddings | |
| # ------------------------ | |
| # Small helpers / logging | |
| # ------------------------ | |
def _log(msg: str) -> None:
    """Print ``msg`` to stdout and flush immediately so it is not held in a buffer."""
    print(msg, flush=True)
def _matrix_payload(np_mat: np.ndarray, labels: Optional[List[str]] = None):
    """
    Build the ``gr.update`` payload used to show a matrix in a Dataframe.

    Values are rounded to 3 decimals and then rendered as strings with exactly
    3 decimal places. When ``labels`` has as many entries as there are rows
    it becomes the index; when it has as many entries as there are columns it
    becomes the column headers. ``headers``, ``col_count`` and ``row_count``
    are all derived from the final frame so they always agree with each other.
    """

    def _three_decimals(value) -> str:
        return f"{value:.3f}"

    table = pd.DataFrame(np_mat).round(3)
    # Convert column by column so every cell ends up as a fixed-width string.
    for column in table.columns:
        table[column] = table[column].apply(_three_decimals)

    n_rows, n_cols = table.shape
    if labels is not None:
        if len(labels) == n_rows:
            table.index = labels
        if len(labels) == n_cols:
            table.columns = labels

    return gr.update(
        value=table,
        headers=list(table.columns),
        col_count=(n_cols, "fixed"),
        row_count=(n_rows, "fixed"),
    )
def _plot_heatmap(D: np.ndarray, labels: Optional[List[str]] = None) -> np.ndarray:
    """Render ``D`` as a dark-themed heatmap and return it as an RGB image.

    Args:
        D: 2-D matrix passed to ``imshow``.
        labels: Optional tick labels. They are applied to both axes only when
            their count equals ``D.shape[0]``; otherwise default ticks are kept.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` copied from the
        rendered canvas (alpha channel dropped).
    """
    background = "#1e1e1e"
    # Apply the dark style only while this figure is built and drawn, so the
    # process-wide Matplotlib rcParams are not changed as a side effect.
    with plt.style.context("dark_background"):
        fig, ax = plt.subplots(figsize=(6, 5), dpi=200, facecolor=background)
        try:
            ax.set_facecolor(background)
            im = ax.imshow(D, cmap="magma")
            cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            cbar.ax.tick_params(colors="white")  # Light colorbar labels
            if labels and len(labels) == D.shape[0]:
                ticks = np.arange(len(labels))
                ax.set_xticks(ticks)
                ax.set_yticks(ticks)
                ax.set_xticklabels(labels, rotation=60, ha="right", fontsize=8, color="white")
                ax.set_yticklabels(labels, fontsize=8, color="white")
            ax.set_title("Dataset Distance Matrix", color="white")
            ax.grid(False)
            # Light axis frame and ticks so they stay visible on the dark background.
            for side in ("bottom", "top", "right", "left"):
                ax.spines[side].set_color("white")
            ax.tick_params(colors="white", which="both")
            fig.tight_layout()
            # Render and grab RGBA buffer (works across Matplotlib versions)
            fig.canvas.draw()
            width, height = fig.canvas.get_width_height()
            buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
            img_rgba = buf.reshape((height, width, 4))
            # Copy before the figure is closed so the result owns its memory.
            return img_rgba[..., :3].copy()
        finally:
            # Always release the figure, even if plotting or rendering raised.
            plt.close(fig)
def _load_uploaded_files(file_objs: List) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
    """Load uploaded datasets. Expect torch files with keys: 'channels', 'labels' (optional).

    Args:
        file_objs: Uploaded entries (or ``None``). Each entry is either an
            object exposing its path as ``.name`` or the path itself.

    Returns:
        One ``(channels, labels, file_basename)`` tuple per file that loaded
        successfully; ``labels`` is ``None`` when the file has no such key.
        Files that fail to load are skipped after a warning is logged.
    """
    out = []
    for f in file_objs or []:
        path = getattr(f, "name", f)
        try:
            # SECURITY NOTE(review): torch.load unpickles the file, and these
            # files come from user uploads. Consider restricting deserialization
            # (e.g. torch.load's `weights_only=True`) once it is confirmed that
            # all supported dataset files still load that way.
            obj = torch.load(path, map_location="cpu")
            ch = obj["channels"]
            y = obj.get("labels", None)
            out.append((ch, y, os.path.basename(path)))
        except KeyError as e:
            # A bare KeyError prints only the key name; say what is wrong.
            _log(f"[WARN] Failed to load {path}: missing required key {e}")
        except Exception as e:
            # Best-effort: one bad upload must not abort the remaining files.
            _log(f"[WARN] Failed to load {path}: {e}")
    return out
def _load_demo_files(paths: List[str]) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
    """Load the selected demo datasets via ``load_pt_dataset``.

    Each result is ``(channels, labels, name)`` where ``name`` is the name of
    the directory containing the file. Paths that fail to load are skipped
    after a warning is logged. ``None`` or an empty list yields ``[]``.
    """
    loaded = []
    for demo_path in (paths or []):
        try:
            channels, targets = load_pt_dataset(demo_path)
            # The parent (scenario) folder name identifies the dataset.
            scenario = os.path.basename(os.path.dirname(demo_path))
            loaded.append((channels, targets, scenario))
        except Exception as err:
            _log(f"[WARN] Failed to load demo dataset {demo_path}: {err}")
    return loaded
def _compute_embeddings(
    framework: str,
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
    umap_cfg: Dict
):
    """
    Dispatch to the embedding builder selected by ``framework``.

    ``"RAW"`` and ``"UMAP"`` call their builders directly. ``"LWM"`` uses the
    encoder from ``get_lwm_encoder()`` and falls back to RAW embeddings (with
    a logged warning) when that returns ``None``. ``umap_cfg`` is read only for
    UMAP, using the keys ``mode``, ``kwargs``, ``repr`` and
    ``angle_delay_bins``, each with a default.

    Returns:
        embs: torch.Tensor [D, n, d]
        labels_per_ds: Optional[List[torch.Tensor]]

    Raises:
        ValueError: if ``framework`` is not one of RAW / UMAP / LWM.
    """

    def _raw():
        return build_raw_embeddings(datasets, n_per_dataset, label_aware)

    if framework == "RAW":
        embs, labels_per_ds = _raw()
    elif framework == "UMAP":
        embs, labels_per_ds = build_umap_embeddings(
            datasets=datasets,
            n_per_dataset=n_per_dataset,
            label_aware=label_aware,
            umap_mode=umap_cfg.get("mode", "supervised"),
            umap_kwargs=umap_cfg.get("kwargs", {}),
            channel_representation=umap_cfg.get("repr", "raw"),
            angle_delay_bins=int(umap_cfg.get("angle_delay_bins", 16)),
        )
    elif framework == "LWM":
        encoder = get_lwm_encoder()
        if encoder is not None:
            embs, labels_per_ds = build_lwm_embeddings(
                model=encoder,
                datasets=datasets,
                n_per_dataset=n_per_dataset,
                label_aware=label_aware
            )
        else:
            _log("[WARN] LWM encoder not available; falling back to RAW embeddings.")
            embs, labels_per_ds = _raw()
    else:
        raise ValueError(f"Unknown framework: {framework}")

    return embs, labels_per_ds
def _compute_distance_matrix(
    embs: torch.Tensor,
    distance_mode: str,
    num_projections: int,
    label_aware: bool,
    labels_per_ds: Optional[List[torch.Tensor]],
    label_weighting: str,
    label_max_per_class: int
) -> torch.Tensor:
    """
    Compute the pairwise dataset distance matrix from ``embs`` ([D, n, d]).

    The two centroid-based modes (``"euclidean_centroid"`` and
    ``"cosine_similarity"``) are used only when ``label_aware`` is false; they
    operate on the mean of each dataset's embeddings along dim 1. Every other
    combination, including a centroid mode with ``label_aware`` set, goes
    through the sliced Wasserstein distance, which receives the label options.
    """
    centroid_metrics = {
        "euclidean_centroid": pairwise_centroid_distances,
        "cosine_similarity": pairwise_cosine_similarity_distances,
    }
    centroid_fn = centroid_metrics.get(distance_mode)
    if centroid_fn is not None and not label_aware:
        centroids = embs.mean(dim=1)  # [D, d]
        return centroid_fn(centroids)

    # Sliced Wasserstein (supports label-aware)
    return sliced_wasserstein_distance_matrix(
        embs,
        num_projections=num_projections,
        labels_per_ds=labels_per_ds,
        label_aware=label_aware,
        label_weighting=label_weighting,
        label_max_per_class=label_max_per_class
    )
| # ------------------------ | |
| # Gradio callbacks | |
| # ------------------------ | |
def refresh_demo_tasks():
    """Return a ``gr.update`` listing the demo tasks, with the first one (if any) selected."""
    available = list_demo_tasks()
    selected = available[0] if available else None
    return gr.update(choices=available, value=selected)
def refresh_demo_scenarios(task: str):
    """Return a ``gr.update`` with the dataset files of ``task``.

    At most the first three files are pre-selected. An empty or missing task
    clears both the choices and the selection.
    """
    if task:
        scenario_files = list_demo_dataset_files(task)
        # Slicing already returns everything when fewer than three files exist.
        return gr.update(choices=scenario_files, value=scenario_files[:3])
    return gr.update(choices=[], value=[])
def _min_max_normalize(mat: np.ndarray) -> np.ndarray:
    """Rescale ``mat`` linearly to [0, 1]; return all zeros when it has no spread."""
    lowest = mat.min()
    highest = mat.max()
    if highest > lowest:  # Avoid division by zero
        return (mat - lowest) / (highest - lowest)
    return np.zeros_like(mat)  # All values are the same, set to 0


def run_compute(
    framework: str,
    distance_mode: str,
    label_aware: bool,
    label_weighting: str,
    label_max_per_class: int,
    num_projections: int,
    n_eval_per_dataset: int,
    demo_task: str,
    demo_files: List[str],
    uploaded_files: List,
    umap_mode: str,
    umap_n_components: int,
    umap_n_neighbors: int,
    umap_min_dist: float,
    umap_metric: str,
    umap_spread: float,
    umap_learning_rate: float,
    umap_n_epochs: int,
    umap_negative_sample_rate: int,
    umap_init: str,
    umap_densmap: bool,
    umap_set_op_mix_ratio: float,
    umap_local_connectivity: float,
    umap_repulsion_strength: float,
    umap_random_state: int,
    channel_representation: str,
    angle_delay_bins: int,
):
    """Callback for the compute button: load datasets, embed, measure, render.

    Steps:
      1. Load the selected demo files (only when a task is also selected) and
         all uploaded files; datasets that fail to load are skipped.
      2. If fewer than two datasets remain, return a prompt message, an empty
         table and no image.
      3. Build embeddings with the chosen framework and compute the pairwise
         distance matrix.
      4. Min-max normalise the matrix to [0, 1] and render it as a table and
         a heatmap, labelled with the dataset names.

    The parameters arrive from UI widgets, so numeric ones are cast explicitly
    before use. ``umap_n_epochs`` of ``0`` or ``None`` is forwarded as
    ``None``.

    Returns:
        ``(status_update, matrix_update, heatmap_image_or_None)``, in that order.
    """
    datasets = []
    if demo_task and demo_files:
        datasets.extend(_load_demo_files(demo_files))
    datasets.extend(_load_uploaded_files(uploaded_files))

    if len(datasets) < 2:
        return (
            gr.update(value="Please provide at least 2 datasets (demo or upload)."),
            _matrix_payload(np.zeros((0, 0))),
            None
        )

    names = [name for _, _, name in datasets]

    umap_kwargs = dict(
        n_components=int(umap_n_components),
        n_neighbors=int(umap_n_neighbors),
        min_dist=float(umap_min_dist),
        metric=umap_metric,
        spread=float(umap_spread),
        learning_rate=float(umap_learning_rate),
        n_epochs=None if umap_n_epochs in [None, 0] else int(umap_n_epochs),
        negative_sample_rate=int(umap_negative_sample_rate),
        init=umap_init,
        densmap=bool(umap_densmap),
        set_op_mix_ratio=float(umap_set_op_mix_ratio),
        local_connectivity=float(umap_local_connectivity),
        repulsion_strength=float(umap_repulsion_strength),
        random_state=int(umap_random_state),
        target_metric="categorical",
        target_weight=0.5,
    )

    embs, labels_per_ds = _compute_embeddings(
        framework=framework,
        datasets=datasets,
        n_per_dataset=int(n_eval_per_dataset),
        label_aware=bool(label_aware),
        umap_cfg={
            "mode": umap_mode,
            "kwargs": umap_kwargs,
            "repr": channel_representation,
            "angle_delay_bins": int(angle_delay_bins),
        }
    )

    D = _compute_distance_matrix(
        embs=embs,
        distance_mode=distance_mode,
        num_projections=int(num_projections),
        label_aware=bool(label_aware),
        labels_per_ds=labels_per_ds,
        label_weighting=label_weighting,
        label_max_per_class=int(label_max_per_class),
    )

    # Normalize distance matrix to [0, 1] range (min-max normalization)
    D_np = _min_max_normalize(D.detach().cpu().numpy())

    img = _plot_heatmap(D_np, labels=names)
    return (
        # U+2705 (check mark) written as an escape so the status text cannot
        # be corrupted again by a wrong file encoding.
        gr.update(value="Done \u2705"),
        _matrix_payload(D_np, labels=names),
        img
    )
| # ------------------------ | |
| # UI | |
| # ------------------------ | |
# Builds the whole Gradio UI and wires the callbacks.
# NOTE(review): the widget nesting below was reconstructed from a flattened
# copy of this file; confirm the column/accordion grouping against the
# deployed layout.
with gr.Blocks(title="Dataset Distancing Lab") as demo:
    # Page header and a short description of the expected file format.
    gr.Markdown(
        """
        # Dataset Distancing Lab
        Compute distances between datasets using **RAW**, **UMAP**, or **LWM** embeddings.
        Upload your `.pt`/`.p` datasets or try the built-in samples under `data/{task}/{scenario}/...`.
        **Format:** each file should be a Torch file with keys:
        - `channels`: `Tensor[N, ...]` (complex supported; real+imag will be concatenated)
        - `labels` (optional): `Tensor[N]`
        """
    )
    # NOTE(review): the leading "π" in this label and in the two button labels
    # further down looks like a mis-encoded emoji; the intended character
    # cannot be determined from this file, so confirm and restore it.
    with gr.Accordion("π Citation", open=False):
        gr.Markdown(
            """
            If you use this lab or methods in your work, please cite:
            ```bibtex
            @INPROCEEDINGS{10942657,
            author={Morais, João and Alikhani, Sadjad and Malhotra, Akshay and Hamidi-Rad, Shahab and Alkhateeb, Ahmed},
            booktitle={2024 58th Asilomar Conference on Signals, Systems, and Computers},
            title={A Dataset Similarity Evaluation Framework for Wireless Communications and Sensing},
            year={2024},
            volume={},
            number={},
            pages={1144-1149},
            keywords={Wireless communication;Dimensionality reduction;Adaptation models;Wireless sensor networks;Nearest neighbor methods;Extraterrestrial measurements;Data structures;Distance measurement;Data models;Sensors},
            doi={10.1109/IEEECONF60004.2024.10942657}}
            ```
            """
        )
    with gr.Row():
        # Column 1: embedding framework, distance options and UMAP settings.
        with gr.Column(scale=1, min_width=320):
            gr.Markdown("### Framework & Distance")
            framework = gr.Radio(
                choices=["RAW", "UMAP", "LWM"],
                value="RAW",
                label="Framework",
            )
            distance_mode = gr.Radio(
                choices=["sliced_wasserstein", "euclidean_centroid", "cosine_similarity"],
                value="sliced_wasserstein",
                label="Distance Mode"
            )
            label_aware = gr.Checkbox(value=True, label="Label-aware (supported by SW distance)")
            label_weighting = gr.Dropdown(
                choices=["uniform", "support"],
                value="uniform",
                label="Label weighting"
            )
            label_max_per_class = gr.Number(value=1e10, precision=0, label="Max samples / class")
            num_projections = gr.Slider(8, 256, value=64, step=1, label="SW #projections")
            n_eval_per_dataset = gr.Slider(32, 4096, value=1024, step=32, label="Samples per dataset")
            gr.Markdown("### UMAP (only if Framework=UMAP)")
            umap_mode = gr.Dropdown(["unsupervised", "supervised"], value="supervised", label="UMAP Mode")
            channel_representation = gr.Dropdown(["raw", "angle_delay"], value="raw", label="Channel representation")
            angle_delay_bins = gr.Slider(4, 128, value=16, step=1, label="Angle-delay bins (if used)")
            with gr.Accordion("Advanced UMAP settings", open=False):
                umap_n_components = gr.Slider(2, 256, value=128, step=1, label="n_components")
                umap_n_neighbors = gr.Slider(2, 128, value=32, step=1, label="n_neighbors")
                umap_min_dist = gr.Slider(0.0, 0.99, value=0.1, step=0.01, label="min_dist")
                umap_metric = gr.Dropdown(
                    ["euclidean", "cosine", "manhattan", "chebyshev", "correlation"],
                    value="euclidean", label="metric"
                )
                umap_spread = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="spread")
                umap_learning_rate = gr.Slider(0.1, 10.0, value=1.0, step=0.1, label="learning_rate")
                umap_n_epochs = gr.Number(value=0, precision=0, label="n_epochs (0 = auto)")
                umap_negative_sample_rate = gr.Slider(1, 50, value=5, step=1, label="negative_sample_rate")
                umap_init = gr.Dropdown(["spectral", "random"], value="spectral", label="init")
                umap_densmap = gr.Checkbox(value=False, label="densMAP")
                umap_set_op_mix_ratio = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="set_op_mix_ratio")
                umap_local_connectivity = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="local_connectivity")
                umap_repulsion_strength = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="repulsion_strength")
                umap_random_state = gr.Number(value=42, precision=0, label="random_state")
        # Column 2: dataset selection (built-in demos and uploads) and the run button.
        with gr.Column(scale=1, min_width=320):
            gr.Markdown("### Demo datasets (data/{task}/{scenario}/...)")
            demo_task = gr.Dropdown(choices=[], value=None, label="Task", interactive=True)
            demo_select = gr.CheckboxGroup(choices=[], value=[], label="Scenarios (files inside each scenario)")
            refresh = gr.Button("π Refresh Demo Lists")
            # The task list starts empty; it is filled when this button is clicked.
            refresh.click(
                fn=refresh_demo_tasks,
                inputs=[],
                outputs=[demo_task]
            )
            # Changing the task reloads the scenario/file choices for it.
            demo_task.change(
                fn=refresh_demo_scenarios,
                inputs=[demo_task],
                outputs=[demo_select]
            )
            gr.Markdown("### Or upload your own")
            uploads = gr.Files(
                label="Upload multiple .pt/.p datasets",
                file_count="multiple",
                file_types=[".pt", ".p"]
            )
            run_btn = gr.Button("π Compute distances", variant="primary")
            status = gr.Markdown("")
        # Column 3: results as a table and as a heatmap image.
        with gr.Column(scale=2):
            gr.Markdown("### Distance Matrix (Table)")
            matrix_out = gr.Dataframe(
                value=None,
                headers=None,
                interactive=False,
                wrap=True,
                row_count=(0, "dynamic"),
                col_count=(0, "dynamic"),
                label="Distances"
            )
            gr.Markdown("### Distance Matrix (Heatmap)")
            heatmap = gr.Image(type="numpy", interactive=False)
    # The input order must match the positional parameters of run_compute.
    run_btn.click(
        fn=run_compute,
        inputs=[
            framework, distance_mode, label_aware, label_weighting, label_max_per_class,
            num_projections, n_eval_per_dataset, demo_task, demo_select, uploads,
            umap_mode, umap_n_components, umap_n_neighbors, umap_min_dist, umap_metric,
            umap_spread, umap_learning_rate, umap_n_epochs, umap_negative_sample_rate,
            umap_init, umap_densmap, umap_set_op_mix_ratio, umap_local_connectivity,
            umap_repulsion_strength, umap_random_state, channel_representation, angle_delay_bins
        ],
        outputs=[status, matrix_out, heatmap]
    )


if __name__ == "__main__":
    demo.launch()