Spaces:
Sleeping
Sleeping
| import re | |
| import einops | |
| import gradio as gr | |
| import matplotlib.cm as cm | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import torch | |
| import torch.nn.functional as F | |
| import torchdiffeq | |
# HTML header rendered at the top of the demo page via gr.HTML. The "head",
# "title" and "description" classes are styled by the CSS passed to gr.Blocks.
DESCRIPTION = """
<div class="head">
<div class="title">Fast LiDAR Data Generation with Rectified Flows</div>
<div class="conference">ICRA 2025</div>
<div class="authors">
<a href="https://kazuto1011.github.io/" target="_blank" rel="noopener"> Kazuto Nakashima</a><sup>1</sup>

<a> Xiaowen Liu</a><sup>1</sup>

<a> Tomoya Miyawaki</a><sup>1</sup>

<a> Yumi Iwashita</a><sup>2</sup>

<a> Ryo Kurazume</a><sup>1</sup>
</div>
<div class="affiliations">
<sup>1</sup>Kyushu University

<sup>2</sup>NASA Jet Propulsion Laboratory
</div>
<div class="materials">
<a href="https://kazuto1011.github.io/r2flow">Project</a> |
<a href="https://arxiv.org/abs/2412.02241">Paper</a> |
<a href="https://github.com/kazuto1011/r2flow">Code</a>
</div>
<br>
<div class="description">
This is a demo of our paper "Fast LiDAR Data Generation with Rectified Flows" accepted to ICRA 2025.<br>
We propose <strong>R2Flow</strong>, a rectified flow-based LiDAR generative model which generates the LiDAR range/reflectance images.<br>
</div>
<br>
</div>
"""
# The demo only samples from pretrained models, so autograd is switched off
# globally and cuDNN benchmark mode is turned on.
torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

# Backend preference: CUDA first, then Apple MPS, falling back to CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# UI label of each checkpoint -> config name handed to torch.hub.load.
# RF = rectified flow, TD = timestep distillation.
model_dict = {
    "1-RF": "r2flow-kitti360-1rf",
    "2-RF": "r2flow-kitti360-2rf",
    "2-RF + 4-TD": "r2flow-kitti360-2rf-4td",
    "2-RF + 2-TD": "r2flow-kitti360-2rf-2td",
    "2-RF + 1-TD": "r2flow-kitti360-2rf-1td",
}

# Keyword arguments shared by every torch.hub.load call; only `config`
# differs between checkpoints.
torch_hub_kwargs = {
    "repo_or_dir": "kazuto1011/r2flow",
    "model": "pretrained_r2flow",
    "device": device,
    "show_info": False,
}
def colorize(tensor: torch.Tensor, cmap_fn=cm.turbo):
    """Turn a batch of scalar maps into 8-bit RGB images with a colormap.

    Args:
        tensor: Values to colorize, expected as (B, 1, H, W) or (B, H, W).
            Values are scaled by 256 and clamped to the 256 colormap bins,
            so inputs outside [0, 1] saturate at the first/last color.
        cmap_fn: Callable that maps an array of positions in [0, 1] to
            colors; only the first three channels of its output are used.

    Returns:
        A uint8 tensor laid out as (B, 3, H, W).
    """
    # Drop the singleton channel axis so the lookup yields (B, H, W, 3).
    if tensor.ndim == 4:
        tensor = tensor.squeeze(1)

    # 256-entry RGB lookup table, moved to the input's dtype and device.
    palette = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    palette = torch.from_numpy(palette).to(tensor)

    bins = (tensor * 256).clamp(0, 255).long()
    rgb = F.embedding(bins, palette).permute(0, 3, 1, 2)
    return rgb.mul(255).clamp(0, 255).byte()
def model_verbose(model, nfe, progress):
    """Wrap ``model`` so that every evaluation advances a progress bar.

    Args:
        model: Callable invoked as ``model(t, x)``.
        nfe: Length of the progress bar (``range(nfe)``).
        progress: Object exposing ``tqdm(iterable, desc=...)`` whose return
            value has an ``update`` method (e.g. ``gr.Progress``).

    Returns:
        A callable with the same ``(t, x)`` signature as ``model``.
    """
    # NOTE(review): the bar total assumes one model call per step; confirm
    # this holds for every solver offered in the UI.
    bar = progress.tqdm(range(nfe), desc="Generating...")

    def tracked(t, x):
        bar.update(1)
        output = model(t, x)
        return output

    return tracked
def generate(nfe: int, solver: str, phase: str, progress=gr.Progress()):
    """Sample one LiDAR scan and render it for the three output widgets.

    Args:
        nfe: Number of sampling steps; converted with ``int`` before use.
        solver: ODE solver name forwarded to ``torchdiffeq.odeint``.
        phase: Checkpoint label, a key of ``model_dict``.
        progress: Gradio progress tracker used to report sampling progress.

    Returns:
        ``(depth, rflct, fig)``: two H x W x 3 uint8 arrays (turbo-colored
        range and reflectance images) and a Plotly figure of the point cloud.
    """
    # The checkpoint is loaded through torch.hub on every call.
    model, lidar_utils, _ = torch.hub.load(config=model_dict[phase], **torch_hub_kwargs)

    with torch.inference_mode():
        num_steps = int(nfe)
        velocity_fn = model_verbose(model, num_steps, progress)
        noise = torch.randn(1, model.in_channels, *model.resolution, device=device)
        timesteps = torch.linspace(0, 1, num_steps + 1, device=device)
        trajectory = torchdiffeq.odeint(
            func=velocity_fn,
            y0=noise,
            t=timesteps,
            method=solver,
        )
        # Only the state at the final time point (t = 1) is the sample.
        x1 = trajectory[-1]

    # Channel 0 carries the range image, channel 1 the reflectance image.
    depth = lidar_utils.restore_metric_depth(x1[:, [0]])
    rflct = lidar_utils.denormalize(x1[:, [1]])
    xyz = lidar_utils.convert_metric_depth(depth, format="cartesian")

    # Color points by height, normalized over a fixed z window
    # (presumably meters -- TODO confirm against lidar_utils).
    z_min, z_max = -2, 0.5
    height = (xyz[:, [2]] - z_min) / (z_max - z_min)
    rgb = colorize(height.clamp(0, 1), cm.viridis) / 255

    # Flatten the image grids into per-point rows for Plotly.
    xyz = einops.rearrange(xyz, "1 c h w -> (h w) c").cpu().numpy()
    rgb = einops.rearrange(rgb, "1 c h w -> (h w) c").cpu().numpy()

    # x and y are negated for display only.
    scatter = go.Scatter3d(
        x=-xyz[..., 0],
        y=-xyz[..., 1],
        z=xyz[..., 2],
        mode="markers",
        marker=dict(size=1, color=rgb),
    )
    scene = {
        axis: dict(showticklabels=False, visible=False)
        for axis in ("xaxis", "yaxis", "zaxis")
    }
    scene["aspectmode"] = "data"
    layout = dict(
        scene=scene,
        margin=dict(l=0, r=0, b=0, t=0),
        paper_bgcolor="white",
        plot_bgcolor="white",
    )
    fig = go.Figure(data=[scatter], layout=layout)

    # Scale the metric depth by the maximum depth before colorizing.
    depth = depth / lidar_utils.max_depth
    depth = colorize(depth, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
    rflct = colorize(rflct, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()

    # Move the modules to the CPU once sampling is finished.
    model.cpu()
    lidar_utils.cpu()

    return depth, rflct, fig
def setup_dropdown(value):
    """Build the solver and step-count dropdowns matching a checkpoint label.

    Timestep-distilled checkpoints (labels containing ``"<N>-TD"``) are locked
    to the Euler solver with exactly ``N`` steps. Every other label gets both
    solvers and step counts 1, 2, 4, ..., 256. A label without an ``<N>-TD``
    pattern is always treated as non-distilled.

    Args:
        value: Checkpoint label, e.g. ``"2-RF"`` or ``"2-RF + 4-TD"``.

    Returns:
        ``(dropdown_solver, dropdown_nfe)`` as ``gr.Dropdown`` components.
        Step-count choices are integers in both cases.
    """
    distilled = re.search(r"(\d+)-TD", value)
    if distilled is not None:
        # Convert to int so the step count has the same type as in the
        # non-distilled branch, where the choices are ints.
        num_step = int(distilled.group(1))
        solver_choices = ["euler"]
        solver_default = "euler"
        nfe_choices = [num_step]
        nfe_default = num_step
    else:
        solver_choices = ["euler", "dopri5"]
        solver_default = "euler"
        nfe_choices = [2**i for i in range(9)]
        nfe_default = 256

    dropdown_solver = gr.Dropdown(
        choices=solver_choices,
        value=solver_default,
        label="ODE solver",
        info="Fixed if TD enabled",
    )
    dropdown_nfe = gr.Dropdown(
        choices=nfe_choices,
        value=nfe_default,
        label="Number of sampling steps",
        info="Fixed if TD enabled",
    )
    return dropdown_solver, dropdown_nfe
# Page layout: the HTML header, then a panel row with a control column
# (device, checkpoint, solver, step count, button) next to an output column
# (range image, reflectance image, point cloud).
with gr.Blocks(
    theme=gr.themes.Ocean(),
    css="""
.head {
text-align: center;
display: block;
font-size: var(--text-xl);
}
.title {
font-size: var(--text-xxl);
font-weight: bold;
margin-top: 2rem;
}
.description {
font-size: var(--text-lg);
}
""",
) as demo:
    gr.HTML(DESCRIPTION)

    with gr.Row(variant="panel"):
        # Left column: sampling controls.
        with gr.Column():
            gr.Textbox(device, label="Running device")
            dropdown_model = gr.Dropdown(
                choices=list(model_dict),
                value="2-RF + 4-TD",
                label="Model checkpoint",
                info="RF: rectified flow, TD: timestep distillation",
            )
            # The initial solver/step dropdowns follow the default checkpoint
            # and are rebuilt whenever the checkpoint selection changes.
            dropdown_solver, dropdown_nfe = setup_dropdown(dropdown_model.value)
            dropdown_model.change(
                setup_dropdown,
                inputs=[dropdown_model],
                outputs=[dropdown_solver, dropdown_nfe],
            )
            btn = gr.Button(value="Generate", variant="primary")

        # Right column: generated outputs.
        with gr.Column():
            range_view = gr.Image(type="numpy", label="Range image")
            rflct_view = gr.Image(type="numpy", label="Reflectance image")
            point_view = gr.Plot(label="Point cloud")

    # Input order matches generate(nfe, solver, phase).
    btn.click(
        generate,
        inputs=[dropdown_nfe, dropdown_solver, dropdown_model],
        outputs=[range_view, rflct_view, point_view],
    )

demo.queue()
demo.launch()