import os
from typing import List, Tuple, Optional, Dict

# Use a non-GUI backend for Matplotlib (required for headless servers)
import matplotlib
matplotlib.use("Agg")

import numpy as np
import torch
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd

from io_demo import (
    list_demo_tasks,
    list_demo_dataset_files,
    load_pt_dataset,
)
from distances_common import (
    pairwise_centroid_distances,
    pairwise_cosine_similarity_distances,
    sliced_wasserstein_distance_matrix,
)
from embed_raw import build_raw_embeddings
from embed_umap import build_umap_embeddings
from embed_lwm import get_lwm_encoder, build_lwm_embeddings


# ------------------------
# Small helpers / logging
# ------------------------

def _log(msg: str):
    print(msg, flush=True)


def _matrix_payload(np_mat: np.ndarray, labels: Optional[List[str]] = None):
    """
    Return a safe gr.update payload for a Dataframe (headers must match col_count).
    """
    df = pd.DataFrame(np_mat)
    # Round and format every value to exactly 3 decimal places
    df = df.round(3)
    for col in df.columns:
        df[col] = df[col].apply(lambda x: f"{x:.3f}")
    if labels is not None and len(labels) == df.shape[0]:
        df.index = labels
    if labels is not None and len(labels) == df.shape[1]:
        df.columns = labels
    return gr.update(
        value=df,
        headers=[str(c) for c in df.columns],  # Dataframe headers must be strings
        col_count=(df.shape[1], "fixed"),
        row_count=(df.shape[0], "fixed")
    )


def _plot_heatmap(D: np.ndarray, labels: Optional[List[str]] = None) -> np.ndarray:
    """Return an RGB image (as a numpy array) of the distance heatmap."""
    # Dark-mode styling to match the app theme
    plt.style.use('dark_background')
    fig, ax = plt.subplots(figsize=(6, 5), dpi=200, facecolor='#1e1e1e')
    ax.set_facecolor('#1e1e1e')
    im = ax.imshow(D, cmap="magma")
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.tick_params(colors='white')  # Light colorbar labels
    if labels and len(labels) == D.shape[0]:
        ax.set_xticks(np.arange(len(labels)))
        ax.set_yticks(np.arange(len(labels)))
        ax.set_xticklabels(labels, rotation=60, ha="right", fontsize=8, color='white')
        ax.set_yticklabels(labels, fontsize=8, color='white')
    ax.set_title("Dataset Distance Matrix", color='white')
    ax.grid(False)
    # Light axis colors
    ax.spines['bottom'].set_color('white')
    ax.spines['top'].set_color('white')
    ax.spines['right'].set_color('white')
    ax.spines['left'].set_color('white')
    ax.tick_params(colors='white', which='both')
    fig.tight_layout()
    # Render and grab the RGBA buffer (works across Matplotlib versions)
    fig.canvas.draw()
    width, height = fig.canvas.get_width_height()
    buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    img_rgba = buf.reshape((height, width, 4))
    img_rgb = img_rgba[..., :3].copy()
    plt.close(fig)
    return img_rgb


def _load_uploaded_files(file_objs: List) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
    """Load uploaded datasets.

    Expects torch files with keys: 'channels', 'labels' (optional).
    """
    out = []
    for f in (file_objs or []):
        path = getattr(f, "name", f)
        try:
            obj = torch.load(path, map_location="cpu")
            ch = obj["channels"]
            y = obj.get("labels", None)
            out.append((ch, y, os.path.basename(path)))
        except Exception as e:
            _log(f"[WARN] Failed to load {path}: {e}")
    return out
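

# Illustrative sketch (not called by the app): how a compatible upload file can be
# produced. The loader above only needs a dict saved with torch.save containing a
# 'channels' tensor and, optionally, a 'labels' tensor; the shapes and class count
# below are hypothetical and chosen only for demonstration.
def _example_save_dataset(path: str = "example_dataset.pt") -> str:
    """Write a toy dataset in the format expected by _load_uploaded_files."""
    n_samples, n_ant, n_sub = 128, 4, 32  # hypothetical dataset dimensions
    channels = torch.randn(n_samples, n_ant, n_sub, dtype=torch.complex64)
    labels = torch.randint(0, 3, (n_samples,))  # optional per-sample class labels
    torch.save({"channels": channels, "labels": labels}, path)
    return path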


def _load_demo_files(paths: List[str]) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
    out = []
    for p in paths or []:
        try:
            ch, y = load_pt_dataset(p)
            # scenario folder name as label
            out.append((ch, y, os.path.basename(os.path.dirname(p))))
        except Exception as e:
            _log(f"[WARN] Failed to load demo dataset {p}: {e}")
    return out


def _compute_embeddings(
    framework: str,
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
    umap_cfg: Dict
):
    """
    Returns:
        embs: torch.Tensor [D, n, d]
        labels_per_ds: Optional[List[torch.Tensor]]
    """
    if framework == "RAW":
        embs, labels_per_ds = build_raw_embeddings(datasets, n_per_dataset, label_aware)
        return embs, labels_per_ds

    if framework == "UMAP":
        embs, labels_per_ds = build_umap_embeddings(
            datasets=datasets,
            n_per_dataset=n_per_dataset,
            label_aware=label_aware,
            umap_mode=umap_cfg.get("mode", "supervised"),
            umap_kwargs=umap_cfg.get("kwargs", {}),
            channel_representation=umap_cfg.get("repr", "raw"),
            angle_delay_bins=int(umap_cfg.get("angle_delay_bins", 16)),
        )
        return embs, labels_per_ds

    if framework == "LWM":
        model = get_lwm_encoder()
        if model is None:
            _log("[WARN] LWM encoder not available; falling back to RAW embeddings.")
            embs, labels_per_ds = build_raw_embeddings(datasets, n_per_dataset, label_aware)
        else:
            embs, labels_per_ds = build_lwm_embeddings(
                model=model,
                datasets=datasets,
                n_per_dataset=n_per_dataset,
                label_aware=label_aware
            )
        return embs, labels_per_ds

    raise ValueError(f"Unknown framework: {framework}")


def _compute_distance_matrix(
    embs: torch.Tensor,
    distance_mode: str,
    num_projections: int,
    label_aware: bool,
    labels_per_ds: Optional[List[torch.Tensor]],
    label_weighting: str,
    label_max_per_class: int
) -> torch.Tensor:
    """
    embs: [D, n, d]
    """
    if distance_mode == "euclidean_centroid" and not label_aware:
        cents = embs.mean(dim=1)  # [D, d]
        return pairwise_centroid_distances(cents)

    if distance_mode == "cosine_similarity" and not label_aware:
        cents = embs.mean(dim=1)  # [D, d]
        return pairwise_cosine_similarity_distances(cents)

    # Sliced Wasserstein (supports label-aware)
    return sliced_wasserstein_distance_matrix(
        embs,
        num_projections=num_projections,
        labels_per_ds=labels_per_ds,
        label_aware=label_aware,
        label_weighting=label_weighting,
        label_max_per_class=label_max_per_class
    )
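

# Reference sketch (not used by the app): a minimal sliced-Wasserstein (SW_2)
# distance between two equally sized point clouds, to illustrate what
# sliced_wasserstein_distance_matrix computes for each pair of datasets. The real
# helper in distances_common may differ (label-aware weighting, sampling, p-norm).
def _example_sliced_wasserstein(X: torch.Tensor, Y: torch.Tensor,
                                num_projections: int = 64) -> torch.Tensor:
    """X, Y: [n, d] point clouds with the same n. Returns a scalar SW_2 estimate."""
    d = X.shape[1]
    theta = torch.randn(num_projections, d)
    theta = theta / theta.norm(dim=1, keepdim=True)  # random unit directions
    # In 1-D, optimal transport between equal-size empirical distributions is a sort
    x_proj, _ = (X @ theta.T).sort(dim=0)  # [n, num_projections]
    y_proj, _ = (Y @ theta.T).sort(dim=0)
    # Average squared transport cost over samples and projections, then take sqrt
    return ((x_proj - y_proj) ** 2).mean().sqrt()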


# ------------------------
# Gradio callbacks
# ------------------------

def refresh_demo_tasks():
    tasks = list_demo_tasks()
    return gr.update(choices=tasks, value=(tasks[0] if tasks else None))


def refresh_demo_scenarios(task: str):
    if not task:
        return gr.update(choices=[], value=[])
    files = list_demo_dataset_files(task)
    default = files[:3] if len(files) >= 3 else files
    return gr.update(choices=files, value=default)


def run_compute(
    framework: str,
    distance_mode: str,
    label_aware: bool,
    label_weighting: str,
    label_max_per_class: int,
    num_projections: int,
    n_eval_per_dataset: int,
    demo_task: str,
    demo_files: List[str],
    uploaded_files: List,
    umap_mode: str,
    umap_n_components: int,
    umap_n_neighbors: int,
    umap_min_dist: float,
    umap_metric: str,
    umap_spread: float,
    umap_learning_rate: float,
    umap_n_epochs: int,
    umap_negative_sample_rate: int,
    umap_init: str,
    umap_densmap: bool,
    umap_set_op_mix_ratio: float,
    umap_local_connectivity: float,
    umap_repulsion_strength: float,
    umap_random_state: int,
    channel_representation: str,
    angle_delay_bins: int,
):
    datasets = []
    if demo_task and demo_files:
        datasets.extend(_load_demo_files(demo_files))
    datasets.extend(_load_uploaded_files(uploaded_files))

    if len(datasets) < 2:
        return (
            gr.update(value="Please provide at least 2 datasets (demo or uploaded)."),
            _matrix_payload(np.zeros((0, 0))),
            None
        )

    names = [name for _, _, name in datasets]

    umap_kwargs = dict(
        n_components=int(umap_n_components),
        n_neighbors=int(umap_n_neighbors),
        min_dist=float(umap_min_dist),
        metric=umap_metric,
        spread=float(umap_spread),
        learning_rate=float(umap_learning_rate),
        n_epochs=None if umap_n_epochs in [None, 0] else int(umap_n_epochs),
        negative_sample_rate=int(umap_negative_sample_rate),
        init=umap_init,
        densmap=bool(umap_densmap),
        set_op_mix_ratio=float(umap_set_op_mix_ratio),
        local_connectivity=float(umap_local_connectivity),
        repulsion_strength=float(umap_repulsion_strength),
        random_state=int(umap_random_state),
        target_metric="categorical",
        target_weight=0.5,
    )

    embs, labels_per_ds = _compute_embeddings(
        framework=framework,
        datasets=datasets,
        n_per_dataset=int(n_eval_per_dataset),
        label_aware=bool(label_aware),
        umap_cfg={
            "mode": umap_mode,
            "kwargs": umap_kwargs,
            "repr": channel_representation,
            "angle_delay_bins": int(angle_delay_bins),
        }
    )

    D = _compute_distance_matrix(
        embs=embs,
        distance_mode=distance_mode,
        num_projections=int(num_projections),
        label_aware=bool(label_aware),
        labels_per_ds=labels_per_ds,
        label_weighting=label_weighting,
        label_max_per_class=int(label_max_per_class),
    )

    D_np = D.detach().cpu().numpy()

    # Min-max normalize the distance matrix to [0, 1]
    d_min = D_np.min()
    d_max = D_np.max()
    if d_max > d_min:  # avoid division by zero
        D_np = (D_np - d_min) / (d_max - d_min)
    else:
        D_np = np.zeros_like(D_np)  # all values identical -> set to 0

    img = _plot_heatmap(D_np, labels=names)
    return (
        gr.update(value="Done ✅"),
        _matrix_payload(D_np, labels=names),
        img
    )
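

# Usage sketch (not wired into the UI): the same pipeline that run_compute drives,
# called headlessly on toy data. The toy tensors are assumptions for illustration;
# whether RAW embeddings accept this exact shape depends on build_raw_embeddings
# in embed_raw.
def _example_headless_run() -> torch.Tensor:
    datasets = [
        (torch.randn(256, 4, 32), torch.randint(0, 3, (256,)), f"toy_{i}")
        for i in range(3)
    ]
    embs, labels_per_ds = _compute_embeddings(
        framework="RAW",
        datasets=datasets,
        n_per_dataset=128,
        label_aware=False,
        umap_cfg={},
    )
    D = _compute_distance_matrix(
        embs=embs,
        distance_mode="sliced_wasserstein",
        num_projections=64,
        label_aware=False,
        labels_per_ds=labels_per_ds,
        label_weighting="uniform",
        label_max_per_class=10_000,
    )
    return D  # [3, 3] tensor of pairwise dataset distances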


# ------------------------
# UI
# ------------------------

with gr.Blocks(title="Dataset Distancing Lab") as demo:
    gr.Markdown(
        """
# Dataset Distancing Lab
Compute distances between datasets using **RAW**, **UMAP**, or **LWM** embeddings.
Upload your `.pt`/`.p` datasets or try the built-in samples under `data/{task}/{scenario}/...`.

**Format:** each file should be a Torch-saved dict with keys:
- `channels`: `Tensor[N, ...]` (complex supported; real and imaginary parts are concatenated)
- `labels` (optional): `Tensor[N]`
        """
    )

    with gr.Accordion("📚 Citation", open=False):
        gr.Markdown(
            """
If you use this lab or its methods in your work, please cite:

```bibtex
@INPROCEEDINGS{10942657,
  author={Morais, João and Alikhani, Sadjad and Malhotra, Akshay and Hamidi-Rad, Shahab and Alkhateeb, Ahmed},
  booktitle={2024 58th Asilomar Conference on Signals, Systems, and Computers},
  title={A Dataset Similarity Evaluation Framework for Wireless Communications and Sensing},
  year={2024},
  volume={},
  number={},
  pages={1144-1149},
  keywords={Wireless communication;Dimensionality reduction;Adaptation models;Wireless sensor networks;Nearest neighbor methods;Extraterrestrial measurements;Data structures;Distance measurement;Data models;Sensors},
  doi={10.1109/IEEECONF60004.2024.10942657}}
```
            """
        )

    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            gr.Markdown("### Framework & Distance")
            framework = gr.Radio(
                choices=["RAW", "UMAP", "LWM"],
                value="RAW",
                label="Framework",
            )
            distance_mode = gr.Radio(
                choices=["sliced_wasserstein", "euclidean_centroid", "cosine_similarity"],
                value="sliced_wasserstein",
                label="Distance Mode"
            )
            label_aware = gr.Checkbox(value=True, label="Label-aware (supported by SW distance)")
            label_weighting = gr.Dropdown(
                choices=["uniform", "support"],
                value="uniform",
                label="Label weighting"
            )
            label_max_per_class = gr.Number(value=1e10, precision=0, label="Max samples / class")
            num_projections = gr.Slider(8, 256, value=64, step=1, label="SW #projections")
            n_eval_per_dataset = gr.Slider(32, 4096, value=1024, step=32, label="Samples per dataset")

            gr.Markdown("### UMAP (only if Framework=UMAP)")
            umap_mode = gr.Dropdown(["unsupervised", "supervised"], value="supervised", label="UMAP Mode")
            channel_representation = gr.Dropdown(["raw", "angle_delay"], value="raw", label="Channel representation")
            angle_delay_bins = gr.Slider(4, 128, value=16, step=1, label="Angle-delay bins (if used)")

            with gr.Accordion("Advanced UMAP settings", open=False):
                umap_n_components = gr.Slider(2, 256, value=128, step=1, label="n_components")
                umap_n_neighbors = gr.Slider(2, 128, value=32, step=1, label="n_neighbors")
                umap_min_dist = gr.Slider(0.0, 0.99, value=0.1, step=0.01, label="min_dist")
                umap_metric = gr.Dropdown(
                    ["euclidean", "cosine", "manhattan", "chebyshev", "correlation"],
                    value="euclidean", label="metric"
                )
                umap_spread = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="spread")
                umap_learning_rate = gr.Slider(0.1, 10.0, value=1.0, step=0.1, label="learning_rate")
                umap_n_epochs = gr.Number(value=0, precision=0, label="n_epochs (0 = auto)")
                umap_negative_sample_rate = gr.Slider(1, 50, value=5, step=1, label="negative_sample_rate")
                umap_init = gr.Dropdown(["spectral", "random"], value="spectral", label="init")
                umap_densmap = gr.Checkbox(value=False, label="densMAP")
                umap_set_op_mix_ratio = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="set_op_mix_ratio")
                umap_local_connectivity = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="local_connectivity")
                umap_repulsion_strength = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="repulsion_strength")
                umap_random_state = gr.Number(value=42, precision=0, label="random_state")

        with gr.Column(scale=1, min_width=320):
            gr.Markdown("### Demo datasets (data/{task}/{scenario}/...)")
            demo_task = gr.Dropdown(choices=[], value=None, label="Task", interactive=True)
            demo_select = gr.CheckboxGroup(choices=[], value=[], label="Scenarios (files inside each scenario)")
            refresh = gr.Button("🔄 Refresh Demo Lists")
            refresh.click(
                fn=refresh_demo_tasks,
                inputs=[],
                outputs=[demo_task]
            )
            demo_task.change(
                fn=refresh_demo_scenarios,
                inputs=[demo_task],
                outputs=[demo_select]
            )

            gr.Markdown("### Or upload your own")
            uploads = gr.Files(
                label="Upload multiple .pt/.p datasets",
                file_count="multiple",
                file_types=[".pt", ".p"]
            )

            run_btn = gr.Button("🚀 Compute distances", variant="primary")
            status = gr.Markdown("")

        with gr.Column(scale=2):
            gr.Markdown("### Distance Matrix (Table)")
            matrix_out = gr.Dataframe(
                value=None,
                headers=None,
                interactive=False,
                wrap=True,
                row_count=(0, "dynamic"),
                col_count=(0, "dynamic"),
                label="Distances"
            )
            gr.Markdown("### Distance Matrix (Heatmap)")
            heatmap = gr.Image(type="numpy", interactive=False)

    run_btn.click(
        fn=run_compute,
        inputs=[
            framework, distance_mode, label_aware, label_weighting, label_max_per_class,
            num_projections, n_eval_per_dataset,
            demo_task, demo_select, uploads,
            umap_mode, umap_n_components, umap_n_neighbors, umap_min_dist, umap_metric,
            umap_spread, umap_learning_rate, umap_n_epochs, umap_negative_sample_rate,
            umap_init, umap_densmap, umap_set_op_mix_ratio, umap_local_connectivity,
            umap_repulsion_strength, umap_random_state,
            channel_representation, angle_delay_bins
        ],
        outputs=[status, matrix_out, heatmap]
    )


if __name__ == "__main__":
    demo.launch()
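
# Note: demo.launch() uses Gradio defaults (local server, first free port).
# Depending on your Gradio version, options such as server_name, server_port, and
# share are available, e.g.:
#   demo.launch(server_name="0.0.0.0", server_port=7860, share=False)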