Spaces: Running

Bachmann Roman Christian committed · Commit a6ebf2a · 1 Parent(s): 3b49518

app.py CHANGED
@@ -26,6 +26,7 @@ from mpl_toolkits.axes_grid1 import ImageGrid
 from tqdm import tqdm
 import random
 from functools import partial
+import time
 
 # import some common detectron2 utilities
 from detectron2 import model_zoo
@@ -290,7 +291,7 @@ def plot_predictions(input_dict, preds, masks, image_size=224):
     plt.close()
 
 
-def inference(img, num_rgb, num_depth, num_semseg, seed, perform_sampling, alphas, num_tokens):
+def inference(img, num_tokens, perform_sampling, num_rgb, num_depth, num_semseg, seed):
     im = Image.open(img)
 
     # Center crop and resize RGB
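A note on why the signature is reordered rather than just renamed: `gr.Interface` passes the values of its `inputs` components to `fn` positionally, so the parameter order must mirror the widget order, and this change only works together with the matching reorder of the `inputs` list in the last hunk below. A minimal sketch of that coupling, with a hypothetical `toy_inference` standing in for the real `inference` (the `gr.inputs.*` API matches the Gradio version used in this file):

import gradio as gr

def toy_inference(img, num_tokens, perform_sampling, num_rgb, num_depth, num_semseg, seed):
    # Each parameter receives the value of the component at the same position in `inputs`.
    return f"tokens={num_tokens}, manual={perform_sampling}, seed={seed}"

gr.Interface(
    fn=toy_inference,
    inputs=[
        gr.inputs.Image(type='filepath'),                    # -> img
        gr.inputs.Slider(minimum=0, maximum=588, default=98),  # -> num_tokens
        gr.inputs.Checkbox(default=False),                   # -> perform_sampling
        gr.inputs.Slider(minimum=0, maximum=196, default=32),  # -> num_rgb
        gr.inputs.Slider(minimum=0, maximum=196, default=32),  # -> num_depth
        gr.inputs.Slider(minimum=0, maximum=196, default=32),  # -> num_semseg
        gr.inputs.Number(default=0),                         # -> seed
    ],
    outputs='text',
)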
@@ -324,21 +325,22 @@ def inference(img, num_rgb, num_depth, num_semseg, seed, perform_sampling, alphas, num_tokens):
     input_dict = {k: v.to(device) for k,v in input_dict.items()}
 
 
-    torch.manual_seed(int(seed)) # change seed to resample new mask
-
     if perform_sampling:
         # Randomly sample masks
 
-
+        torch.manual_seed(int(time.time())) # Random mode is random
 
         preds, masks = multimae.forward(
            input_dict,
            mask_inputs=True, # True if forward pass should sample random masks
            num_encoded_tokens=num_tokens,
-           alphas=alphas
+           alphas=1.0
        )
     else:
         # Randomly sample masks using the specified number of tokens per modality
+
+        torch.manual_seed(int(seed)) # change seed to resample new mask
+
         task_masks = {domain: torch.ones(1,196).long().to(device) for domain in DOMAINS}
         selected_rgb_idxs = torch.randperm(196)[:num_rgb]
         selected_depth_idxs = torch.randperm(196)[:num_depth]
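The hunk cuts off mid-way through the manual branch. For orientation, here is a plausible sketch of the masking pattern the visible context lines suggest: every modality starts fully masked and the sampled patch indices are then revealed. The `DOMAINS` values and the 1 = masked / 0 = visible convention are assumptions, not confirmed by this diff:

import torch

DOMAINS = ['rgb', 'depth', 'semseg']  # assumed modality keys

def build_task_masks(num_visible, num_patches=196, device='cpu'):
    # 1 = masked patch, 0 = visible patch (assumed MultiMAE convention).
    task_masks = {d: torch.ones(1, num_patches).long().to(device) for d in DOMAINS}
    for domain, n in num_visible.items():
        idxs = torch.randperm(num_patches)[:n]  # sample n distinct patch indices
        task_masks[domain][0, idxs] = 0         # reveal the sampled patches
    return task_masks

masks = build_task_masks({'rgb': 32, 'depth': 32, 'semseg': 32})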
@@ -365,7 +367,7 @@ title = "MultiMAE"
 description = "Gradio demo for MultiMAE: Multi-modal Multi-task Masked Autoencoders. \
 Upload your own images or try one of the examples below to explore the multi-modal masked reconstruction of a pre-trained MultiMAE model. \
 Uploaded images are pseudo labeled using a DPT trained on Omnidata depth, and a Mask2Former trained on COCO. \
-Choose the number of visible tokens using the sliders below
+Choose the number of visible tokens using the sliders below and see how MultiMAE reconstructs the modalities!"
 
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.01678' \
 target='_blank'>MultiMAE: Multi-modal Multi-task Masked Autoencoders</a> | \
@@ -375,24 +377,18 @@ css = '.output-image{height: 713px !important}'
 
 # Example images
 os.system("wget https://i.imgur.com/c9ObJdK.jpg")
-examples = [['c9ObJdK.jpg',
+examples = [['c9ObJdK.jpg', 98, False, 32, 32, 32, 0]]
 
 gr.Interface(
     fn=inference,
     inputs=[
         gr.inputs.Image(label='RGB input image', type='filepath'),
+        gr.inputs.Slider(label='Number of input tokens', default=98, step=1, minimum=0, maximum=588),
+        gr.inputs.Checkbox(label='Manual mode: Check this to manually set the number of input tokens per modality using the sliders below', default=False),
         gr.inputs.Slider(label='Number of RGB input tokens', default=32, step=1, minimum=0, maximum=196),
         gr.inputs.Slider(label='Number of depth input tokens', default=32, step=1, minimum=0, maximum=196),
         gr.inputs.Slider(label='Number of semantic input tokens', default=32, step=1, minimum=0, maximum=196),
         gr.inputs.Number(label='Random seed: Change this to sample different masks', default=0),
-        gr.inputs.Checkbox(label='Randomize the number of tokens: Check this to ignore the above sliders and randomly sample the number \
-            of tokens per modality using the parameters below', default=False),
-        gr.inputs.Slider(label='Symmetric Dirichlet concentration parameter (α > 0). Low values (α << 1.0) result in a sampling behavior, \
-            where most of the time, all visible tokens will be sampled from a single modality. High values \
-            (α >> 1.0) result in similar numbers of tokens being sampled for each modality. α = 1.0 is equivalent \
-            to uniform sampling over the simplex and contains both previous cases and everything in between.',
-            default=1.0, step=0.1, minimum=0.1, maximum=5.0),
-        gr.inputs.Slider(label='Number of input tokens', default=98, step=1, minimum=0, maximum=588),
     ],
     outputs=[
         gr.outputs.Image(label='MultiMAE predictions', type='file')
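The removed slider's label describes how the symmetric Dirichlet concentration α, now hard-coded as alphas=1.0 in inference, shapes the split of visible tokens across modalities: low α concentrates the budget in one modality, high α spreads it evenly, and α = 1.0 samples uniformly over the simplex. A small, self-contained illustration of that behavior, independent of MultiMAE's internals:

import torch
from torch.distributions import Dirichlet

def sample_token_split(alpha, num_tokens=98, num_modalities=3):
    # Draw modality proportions from a symmetric Dirichlet(alpha) over the simplex.
    props = Dirichlet(torch.full((num_modalities,), alpha)).sample()
    # Scale to token counts; rounding means the sum is only approximately num_tokens.
    return (props * num_tokens).round().long()

print(sample_token_split(0.1))  # low alpha: most tokens usually land in one modality
print(sample_token_split(5.0))  # high alpha: roughly even split across modalities
print(sample_token_split(1.0))  # alpha = 1.0: uniform over the simplex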