diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..19f846cfea6e8d0869d82423a206db6e52aaa6ff 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..831669619f694d7dc1d9a82ed9ae0c235decb783 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +cache/ +output/ \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e2952c5879bd150fda1f431e62f125ea35464a4d --- /dev/null +++ b/app.py @@ -0,0 +1,165 @@ +import os +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ( + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed +) +import torch +from contextlib import nullcontext +import trimesh +import gradio as gr +from gradio_imageslider import ImageSlider +from da2.utils.base import load_config +from da2.utils.model import load_model +from da2.utils.io import ( + read_cv2_image, + torch_transform, + tensorize +) +from da2.utils.vis import colorize_distance +from da2.utils.d2pc import distance2pointcloud +from datetime import ( + timedelta, + datetime +) +import cv2 +import numpy as np + +last_glb_path = None + +def prepare_to_run_demo(): + config = load_config('configs/infer.json') + kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=config['accelerator']['timeout'])) + output_dir = f'output/infer' + if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) + accu_steps = config['accelerator']['accumulation_nsteps'] + accelerator = Accelerator( + gradient_accumulation_steps=accu_steps, + mixed_precision=config['accelerator']['mixed_precision'], + log_with=config['accelerator']['report_to'], + project_config=ProjectConfiguration(project_dir=output_dir), + kwargs_handlers=[kwargs] + ) + logger = get_logger(__name__, log_level='INFO') + config['env']['logger'] = logger + set_seed(config['env']['seed']) + return config, accelerator + +def read_mask_demo(mask_path, shape): + if mask_path is None: + return np.ones(shape[1:]) > 0 + mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) + mask = mask > 0 + return mask + +def load_infer_data_demo(image, mask, model_dtype, device): + cv2_image = read_cv2_image(image) + image = torch_transform(cv2_image) + mask = read_mask_demo(mask, image.shape) + image = tensorize(image, model_dtype, device) + return image, cv2_image, mask + +def ply2glb(ply_path, glb_path): + pcd = trimesh.load(ply_path) + points = np.asarray(pcd.vertices) + colors = np.asarray(pcd.visual.vertex_colors) + cloud = trimesh.points.PointCloud(vertices=points, colors=colors) + cloud.export(glb_path) + os.remove(ply_path) + +def fn(image_path, mask_path): + global last_glb_path + config, accelerator = prepare_to_run_demo() + model = load_model(config, accelerator) + image, cv2_image, mask = load_infer_data_demo(image_path, mask_path, + model_dtype=config['spherevit']['dtype'], device=accelerator.device) + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + with autocast_ctx, torch.no_grad(): + distance = model(image).cpu().numpy()[0] + if 
last_glb_path is not None: + os.remove(last_glb_path) + distance_vis = colorize_distance(distance, mask) + save_path = f'cache/tmp_{datetime.now().strftime("%Y%m%d_%H%M%S")}.glb' + last_glb_path = save_path + normal_image = distance2pointcloud(distance, cv2_image, mask, save_path=save_path.replace('.glb', '.ply'), return_normal=True, save_distance=False) + ply2glb(save_path.replace('.glb', '.ply'), save_path) + return save_path, [distance_vis, normal_image] + +inputs = [ + gr.Image(label="Input Image", type="filepath"), + gr.Image(label="Input Mask", type="filepath"), +] +outputs = [ + gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Point Cloud"), + gr.ImageSlider( + label="Output Depth / Normal (transformed from the depth)", + type="pil", + slider_position=75, + ) +] + +demo = gr.Interface( + fn=fn, + title="DA2: Depth Anything in Any Direction", + description=""" +

+ [badges: GitHub stars, social links]
+ Please consider starring our GitHub Repo if you find this demo useful!

+

Note: the "Input Mask" is optional; if no mask is provided, all pixels are assumed to be valid.

+ """, + inputs=inputs, + outputs=outputs, + examples=[ + [os.path.join(os.path.dirname(__file__), "assets/demos/a1.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/a2.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/a3.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/a4.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/b0.png"), + os.path.join(os.path.dirname(__file__), "assets/masks/b0.png")], + [os.path.join(os.path.dirname(__file__), "assets/demos/b1.png"), + os.path.join(os.path.dirname(__file__), "assets/masks/b1.png")], + [os.path.join(os.path.dirname(__file__), "assets/demos/a5.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/a6.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/a7.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/a8.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/b2.png"), + os.path.join(os.path.dirname(__file__), "assets/masks/b2.png")], + [os.path.join(os.path.dirname(__file__), "assets/demos/b3.png"), + os.path.join(os.path.dirname(__file__), "assets/masks/b3.png")], + [os.path.join(os.path.dirname(__file__), "assets/demos/a9.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/a10.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/a11.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/a0.png"), None], + [os.path.join(os.path.dirname(__file__), "assets/demos/b4.png"), + os.path.join(os.path.dirname(__file__), "assets/masks/b4.png")], + [os.path.join(os.path.dirname(__file__), "assets/demos/b5.png"), + os.path.join(os.path.dirname(__file__), "assets/masks/b5.png")], + ], + examples_per_page=20 +) + +demo.launch( + server_name="0.0.0.0", + server_port=6381, +) diff --git a/assets/badges/icon2.png b/assets/badges/icon2.png new file mode 100644 index 0000000000000000000000000000000000000000..847382c831464926fb2b7170b5868930ea86fc34 --- /dev/null +++ b/assets/badges/icon2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d254fc5009dd41b367790aa9e45f05770b81ed62c67d8cc713bee4608567218f +size 6765 diff --git a/assets/badges/teaser.jpg b/assets/badges/teaser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..395a804413725d527c02a1b8bba4935a98812c18 --- /dev/null +++ b/assets/badges/teaser.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c6786218d0a17115e6ed1320434b2b47101290a7e244f2eed1ebe70e4822464 +size 1201407 diff --git a/assets/demos/a0.png b/assets/demos/a0.png new file mode 100644 index 0000000000000000000000000000000000000000..267653c09acf2f8a26be91d6fbd4c011eda89acd --- /dev/null +++ b/assets/demos/a0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eedc66f98cf0a949602f691c3eed51511ae520cf8f63674abe542741ba6090b8 +size 744091 diff --git a/assets/demos/a1.png b/assets/demos/a1.png new file mode 100644 index 0000000000000000000000000000000000000000..815eadf96f883a47a2b4fdf171d5cbef8d0c9017 --- /dev/null +++ b/assets/demos/a1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:906f336ab4c6561ee85b9cb883a6aa34cf11289fc86b6a4e4382baed56981aa7 +size 821703 diff --git a/assets/demos/a10.png b/assets/demos/a10.png new file mode 100644 index 0000000000000000000000000000000000000000..da37cd6ddab7133ccb9ce12b2d2660bc8228915e --- /dev/null +++ b/assets/demos/a10.png @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:d6d058aef9322964f5d36de90ab91470e283acab248604bcd488a43c680a9e7d +size 881818 diff --git a/assets/demos/a11.png b/assets/demos/a11.png new file mode 100644 index 0000000000000000000000000000000000000000..12620f1f8717c787c482c0b0bbc40c2794d3ae48 --- /dev/null +++ b/assets/demos/a11.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45af8c71b8d44880503b5da1b5f67a0d5638860b9f9149cae7d16a3a3975d090 +size 848394 diff --git a/assets/demos/a2.png b/assets/demos/a2.png new file mode 100644 index 0000000000000000000000000000000000000000..2a25f8f6fb8b191147612ef4b114aad93802f436 --- /dev/null +++ b/assets/demos/a2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fa931d70c6220cec0b56a9cdf651f12fa35436d937cd2cf481d10dddb2a114e +size 809628 diff --git a/assets/demos/a3.png b/assets/demos/a3.png new file mode 100644 index 0000000000000000000000000000000000000000..5cfb653c14bc04496ec99a06f0309e922f9ff8df --- /dev/null +++ b/assets/demos/a3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a85573ac5d51a261d82b23475488e769bd9b3e392948e60e6dc73f0c7ace762b +size 854468 diff --git a/assets/demos/a4.png b/assets/demos/a4.png new file mode 100644 index 0000000000000000000000000000000000000000..c7d9899b683013c1277dc1bf274e59c106259fbc --- /dev/null +++ b/assets/demos/a4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0a544ec4b542c59f1fbfaf99f86eb60b4c0dbce7c8e4b1bac9e6e23e889c7ec +size 812626 diff --git a/assets/demos/a5.png b/assets/demos/a5.png new file mode 100644 index 0000000000000000000000000000000000000000..716e2c9ad54f622fc73d622dc269db9d35f33915 --- /dev/null +++ b/assets/demos/a5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e36ed78b74223eae24f8c85f1cdab00d1a3a5b494fec807240cb7d3427fad87 +size 847578 diff --git a/assets/demos/a6.png b/assets/demos/a6.png new file mode 100644 index 0000000000000000000000000000000000000000..4938bd13ce9be9ba95427b707c644ca6edbb3a3e --- /dev/null +++ b/assets/demos/a6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e48031fcd3e5a84e4ea4513a23e2ec8150f8ec3fbdae1d4b2d51fc67ac588fe6 +size 818477 diff --git a/assets/demos/a7.png b/assets/demos/a7.png new file mode 100644 index 0000000000000000000000000000000000000000..f94efe1f72d5c0422246f5d236617f0d6c69262a --- /dev/null +++ b/assets/demos/a7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12b99fdddea8eefb6885114bd386fc4fad0484e13c85c88364a43396f9cef3f9 +size 904680 diff --git a/assets/demos/a8.png b/assets/demos/a8.png new file mode 100644 index 0000000000000000000000000000000000000000..bb1f55614094f39a16c40221801ccbeafc9a506f --- /dev/null +++ b/assets/demos/a8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b29df5b6294742acc43d8ce41073b335e98024459273b77d9b943fd3583ac35 +size 784328 diff --git a/assets/demos/a9.png b/assets/demos/a9.png new file mode 100644 index 0000000000000000000000000000000000000000..faec9983f223633853a0eedd5d4399f0ff108f76 --- /dev/null +++ b/assets/demos/a9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba92bf3adf1d1b2a775d5b0f895a16876159fc1a43d98328c923fdc994d6e346 +size 910249 diff --git a/assets/demos/b0.png b/assets/demos/b0.png new file mode 100644 index 0000000000000000000000000000000000000000..c6ec9dd9975fe81f94760b38ca84eab425ab1f6b --- /dev/null +++ b/assets/demos/b0.png @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:3b610ae826372778853553810ef0e07e4f91d8507549dc0f5f32eca038348a37 +size 850392 diff --git a/assets/demos/b1.png b/assets/demos/b1.png new file mode 100644 index 0000000000000000000000000000000000000000..27d7c6e8ec92cf3a476966d65e1325c2230f1964 --- /dev/null +++ b/assets/demos/b1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2df3207be859cf8524e9a00a76efb606e626ca4cc9dbd81178fe24de43a6b97b +size 798128 diff --git a/assets/demos/b2.png b/assets/demos/b2.png new file mode 100644 index 0000000000000000000000000000000000000000..4beb346587164c6867c2859d848106aa5f02522d --- /dev/null +++ b/assets/demos/b2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:790218133cd507f1f9ca65fcdff60f74325df39ebd0df1d5b6e6261a8dfd29a8 +size 863217 diff --git a/assets/demos/b3.png b/assets/demos/b3.png new file mode 100644 index 0000000000000000000000000000000000000000..3acd2114400b10a557021a5904e906cf00b8942b --- /dev/null +++ b/assets/demos/b3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:843b680077e114451285efc6536e811739cbbab07ade423459a5bc24e747455f +size 650671 diff --git a/assets/demos/b4.png b/assets/demos/b4.png new file mode 100644 index 0000000000000000000000000000000000000000..6ba15860a680613beb8ea710414407470f26a6e8 --- /dev/null +++ b/assets/demos/b4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5615e49fa1bea5ee049a66bbe577d48dd63f441e86a4ae5b225136e7e2295187 +size 804398 diff --git a/assets/demos/b5.png b/assets/demos/b5.png new file mode 100644 index 0000000000000000000000000000000000000000..cccf4ea0427459d11d78447e7ece5a2dd89304e0 --- /dev/null +++ b/assets/demos/b5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7957ee9e54dd6b61b74014412ece3de7bbe999ae0c0be41c4d762d62d8352656 +size 669137 diff --git a/assets/masks/b0.png b/assets/masks/b0.png new file mode 100644 index 0000000000000000000000000000000000000000..3e6417f05a4b4dcebaab35d11352e5fa2e257430 --- /dev/null +++ b/assets/masks/b0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7495c6c7672f1b0551f5640a0344a3730744cfa535697307afa917fbf46466ad +size 6993 diff --git a/assets/masks/b1.png b/assets/masks/b1.png new file mode 100644 index 0000000000000000000000000000000000000000..c7d1fd480be44bfd86c4002186f9cb0c69eb4cda --- /dev/null +++ b/assets/masks/b1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1aea3b6a9a99adbcdb71fcbc9eb5c5f18fbdc36b38829d7ba972183a7ec564e3 +size 5357 diff --git a/assets/masks/b2.png b/assets/masks/b2.png new file mode 100644 index 0000000000000000000000000000000000000000..5a48bd5586e7dd2e05a4b10e093a7ab8118837b5 --- /dev/null +++ b/assets/masks/b2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4360d8523cb2309b29ed549c6a7c84dd0d6a3ca5f55720ae43b728668dfe6c9b +size 7703 diff --git a/assets/masks/b3.png b/assets/masks/b3.png new file mode 100644 index 0000000000000000000000000000000000000000..7a332ee4df7d8f4420fe536dedd5e46884b00667 --- /dev/null +++ b/assets/masks/b3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1e6f1d40d8f9e8e5593bf3f5fe67967528b8afcbfaf605658f19004edbdb10d +size 4568 diff --git a/assets/masks/b4.png b/assets/masks/b4.png new file mode 100644 index 0000000000000000000000000000000000000000..8d63e68b058bd3594d2f464bdad92a9bb4c6ee2e --- /dev/null +++ b/assets/masks/b4.png @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:8a2a1018ad95749d83193fc0f333e1af04de119857e2564c5fbefa41301f2226 +size 5428 diff --git a/assets/masks/b5.png b/assets/masks/b5.png new file mode 100644 index 0000000000000000000000000000000000000000..acf9f30cf3d75c007ba3e1fe04e84a319d9ca7c8 --- /dev/null +++ b/assets/masks/b5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c38cca29eec4baaeb7b765f595f28d13e1fcaf7707bed7ad83277b12eee1f504 +size 4883 diff --git a/configs/accelerate/0.yaml b/configs/accelerate/0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f7938e128402b64e2690255025224e303654b52 --- /dev/null +++ b/configs/accelerate/0.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: 'NO' +downcast_bf16: 'no' +gpu_ids: '0' +machine_rank: 0 +main_training_function: main +mixed_precision: 'no' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/configs/infer.json b/configs/infer.json new file mode 100755 index 0000000000000000000000000000000000000000..55ecd9accb3d2562cb1a0c5e236bb49bcf0f1c65 --- /dev/null +++ b/configs/infer.json @@ -0,0 +1,39 @@ +{ + "env": { + "seed": 42, + "verbose": true + }, + "accelerator": { + "report_to": ["tensorboard"], + "mixed_precision": "fp16", + "accumulation_nsteps": 4, + "timeout": 36000 + }, + "inference": { + "images": "assets/demos", + "masks": "assets/masks", + "min_pixels": 580000, + "max_pixels": 620000 + }, + "spherevit": { + "vit_w_esphere": { + "input_dims": [1024, 1024, 1024, 1024], + "hidden_dim": 512, + "num_heads": 8, + "expansion": 4, + "num_layers_head": [2, 2, 2], + "dropout": 0.0, + "layer_scale": 0.0001, + "out_dim": 64, + "kernel_size": 3, + "num_prompt_blocks": 1, + "use_norm": false + }, + "sphere": { + "width": 1092, + "height": 546, + "hfov": 6.2832, + "vfov": 3.1416 + } + } +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3cceff4cade5de6cbc69eed216433351c475866b --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +pip install -e src/ \ No newline at end of file diff --git a/src/da2.egg-info/PKG-INFO b/src/da2.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..8a628b59766e52f98401f3e400f59e7b61280ecd --- /dev/null +++ b/src/da2.egg-info/PKG-INFO @@ -0,0 +1,23 @@ +Metadata-Version: 2.4 +Name: da2 +Version: 0.1.0 +Summary: For the implementation of DA^2: Depth Anything in Any Direction +Author-email: "H. 
Li" +Requires-Dist: torch==2.5.0 +Requires-Dist: torchvision==0.20.0 +Requires-Dist: torchaudio==2.5.0 +Requires-Dist: xformers==0.0.28.post2 +Requires-Dist: diffusers==0.32.0 +Requires-Dist: tensorboard==2.18.0 +Requires-Dist: utils3d@ git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900 +Requires-Dist: opencv-python==4.12.0.88 +Requires-Dist: gradio==5.49.0 +Requires-Dist: gradio-client==1.13.3 +Requires-Dist: gradio-imageslider==0.0.20 +Requires-Dist: accelerate==1.1.1 +Requires-Dist: omegaconf==2.3.0 +Requires-Dist: tabulate==0.9.0 +Requires-Dist: einops==0.8.0 +Requires-Dist: timm==1.0.15 +Requires-Dist: trimesh==4.5.2 +Requires-Dist: transformers==4.46.3 diff --git a/src/da2.egg-info/SOURCES.txt b/src/da2.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..d75e99d41d9df2936757319c33ca8f10a772280e --- /dev/null +++ b/src/da2.egg-info/SOURCES.txt @@ -0,0 +1,28 @@ +pyproject.toml +da2/__init__.py +da2.egg-info/PKG-INFO +da2.egg-info/SOURCES.txt +da2.egg-info/dependency_links.txt +da2.egg-info/requires.txt +da2.egg-info/top_level.txt +da2/model/__init__.py +da2/model/base.py +da2/model/sphere.py +da2/model/spherevit.py +da2/model/vit_w_esphere.py +da2/model/dinov2/__init__.py +da2/model/dinov2/attention.py +da2/model/dinov2/block.py +da2/model/dinov2/dino_head.py +da2/model/dinov2/dinovit.py +da2/model/dinov2/drop_path.py +da2/model/dinov2/layer_scale.py +da2/model/dinov2/mlp.py +da2/model/dinov2/patch_embed.py +da2/model/dinov2/swiglu_ffn.py +da2/utils/__init__.py +da2/utils/base.py +da2/utils/d2pc.py +da2/utils/io.py +da2/utils/model.py +da2/utils/vis.py \ No newline at end of file diff --git a/src/da2.egg-info/dependency_links.txt b/src/da2.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/da2.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/da2.egg-info/requires.txt b/src/da2.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..313b5fd2bed588a94da9f04276a28a465976801b --- /dev/null +++ b/src/da2.egg-info/requires.txt @@ -0,0 +1,18 @@ +torch==2.5.0 +torchvision==0.20.0 +torchaudio==2.5.0 +xformers==0.0.28.post2 +diffusers==0.32.0 +tensorboard==2.18.0 +utils3d@ git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900 +opencv-python==4.12.0.88 +gradio==5.49.0 +gradio-client==1.13.3 +gradio-imageslider==0.0.20 +accelerate==1.1.1 +omegaconf==2.3.0 +tabulate==0.9.0 +einops==0.8.0 +timm==1.0.15 +trimesh==4.5.2 +transformers==4.46.3 diff --git a/src/da2.egg-info/top_level.txt b/src/da2.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..e9f3aca6d46fc8bd68eef426cddaa20a40de4253 --- /dev/null +++ b/src/da2.egg-info/top_level.txt @@ -0,0 +1 @@ +da2 diff --git a/src/da2/__init__.py b/src/da2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb1dddc3c657f4ce1b627593983469f7fbfd94b --- /dev/null +++ b/src/da2/__init__.py @@ -0,0 +1,25 @@ +from .utils.base import ( + prepare_to_run +) +from .utils.model import ( + load_model +) +from .utils.io import ( + load_infer_data +) +from .utils.vis import ( + colorize_distance, + concatenate_images +) +from .utils.d2pc import ( + distance2pointcloud +) + +__all__ = [ + 'prepare_to_run', + 'load_model', + 'load_infer_data', + 'colorize_distance', + 'concatenate_images', + 'distance2pointcloud' +] diff --git 
a/src/da2/__pycache__/__init__.cpython-312.pyc b/src/da2/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82b6c82f97676a7dfc14c2e303e9a83e818ec09e Binary files /dev/null and b/src/da2/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/da2/model/__init__.py b/src/da2/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df880766b1a941dbe4e413c65aae041e5980c3e9 --- /dev/null +++ b/src/da2/model/__init__.py @@ -0,0 +1,11 @@ +from .spherevit import ( + SphereViT +) +from .vit_w_esphere import ( + ViT_w_Esphere +) + +__all__ = [ + 'SphereViT', + 'ViT_w_Esphere', +] diff --git a/src/da2/model/__pycache__/__init__.cpython-312.pyc b/src/da2/model/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e49302a36e05a435a3381fa203f125e421359d48 Binary files /dev/null and b/src/da2/model/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/da2/model/__pycache__/base.cpython-312.pyc b/src/da2/model/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5ab12b142a2a78e3a7c15ed35f7a2288b78f03d Binary files /dev/null and b/src/da2/model/__pycache__/base.cpython-312.pyc differ diff --git a/src/da2/model/__pycache__/sphere.cpython-312.pyc b/src/da2/model/__pycache__/sphere.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24317b459f977136fdf0601bc0cd61787ff132c8 Binary files /dev/null and b/src/da2/model/__pycache__/sphere.cpython-312.pyc differ diff --git a/src/da2/model/__pycache__/spherevit.cpython-312.pyc b/src/da2/model/__pycache__/spherevit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baa794d496bf1f6153affb58840f9ac2709feeb7 Binary files /dev/null and b/src/da2/model/__pycache__/spherevit.cpython-312.pyc differ diff --git a/src/da2/model/__pycache__/vit_w_esphere.cpython-312.pyc b/src/da2/model/__pycache__/vit_w_esphere.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6743cfac8ca352235bd0ef9e1ec44dd81ce58d54 Binary files /dev/null and b/src/da2/model/__pycache__/vit_w_esphere.cpython-312.pyc differ diff --git a/src/da2/model/base.py b/src/da2/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..03d467f3184133b590f76af9ff453386719d0736 --- /dev/null +++ b/src/da2/model/base.py @@ -0,0 +1,393 @@ +import torch +import torch.nn as nn +from math import log2, pi +from typing import Tuple +import torch.nn.functional as F +from einops import rearrange +from functools import partial + + +def fourier_dimension_expansion( + x: torch.Tensor, + dim: int = 512, + max_freq: int = 64, + use_cos: bool = True, + use_log: bool = True, +): + device, dtype, input_dim = x.device, x.dtype, x.shape[-1] + # input_dim: 2 + num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim + # num_bands = 512 // 2 = 256 + if use_log: + scales = 2.0 ** torch.linspace( + 0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype + ) + else: + scales = torch.linspace( + 1.0, max_freq / 2, num_bands, device=device, dtype=dtype + ) + x = x.unsqueeze(-1) + scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] + x = x * scales * pi + x = torch.cat( + ( + [x.sin(), x.cos()] + if use_cos + else [ + x.sin(), + ] + ), + dim=-1, + ) + x = x.flatten(-2) + return x + +def flatten( + flat_tensor: torch.Tensor, + old: Tuple[int, int], + new: Tuple[int, int], +) -> torch.Tensor: + if old[0] 
== new[0] and old[1] == new[1]: + return flat_tensor + tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute( + 0, 3, 1, 2 + ) # b c h w + tensor_interp = F.interpolate( + tensor, + size=(new[0], new[1]), + mode='nearest', + ) + flat_tensor_interp = tensor_interp.view( + flat_tensor.shape[0], -1, new[0] * new[1] + ).permute( + 0, 2, 1 + ) # b (h w) c + return flat_tensor_interp.contiguous() + + +class DimensionAligner(nn.Module): + def __init__(self, input_dims: list[int], hidden_dim: int): + super().__init__() + self.aligners = nn.ModuleList([]) + self.num_chunks = len(input_dims) + self.checkpoint = True + for input_dim in input_dims: + self.aligners.append(nn.Linear(input_dim, hidden_dim)) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + outs = [self.aligners[i](x) for i, x in enumerate(xs)] + return outs + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float | torch.Tensor = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +def exists(val): + return val is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +class SwiGLU(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gates = x.chunk(2, dim=-1) + return x * F.silu(gates) + + +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + expansion: int = 4, + dropout: float = 0.0, + gated: bool = False, + output_dim: int | None = None, + ): + super().__init__() + if gated: + expansion = int(expansion * 2 / 3) + hidden_dim = int(input_dim * expansion) + output_dim = default(output_dim, input_dim) + self.norm = nn.LayerNorm(input_dim) + self.proj1 = nn.Linear(input_dim, hidden_dim) + self.proj2 = nn.Linear(hidden_dim, output_dim) + self.act = nn.GELU() if not gated else SwiGLU() + self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.proj1(x) + x = self.act(x) + x = self.proj2(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + detach_query: bool = False, + residual_ls: bool = False, + ): + super().__init__() + self.dropout = dropout + self.num_heads = num_heads + self.hidden_dim = dim + context_dim = dim if context_dim is None else context_dim + self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated) + self.kv = nn.Linear(context_dim, dim * 2, bias=False) + self.q = nn.Linear(dim, dim, bias=False) + self.norm_attnx = nn.LayerNorm(dim) + self.norm_attnctx = nn.LayerNorm(context_dim) + self.cosine = cosine + self.out = nn.Linear(dim, dim, bias=False) + self.ls1_1 = ( + LayerScale(dim, layer_scale) + if layer_scale > 0.0 and not residual_ls + else nn.Identity() + ) + self.ls1_2 = ( + LayerScale(dim, layer_scale) + if layer_scale > 0.0 and residual_ls + else nn.Identity() + ) + self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + self.detach_query = detach_query + + def attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: 
torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + rope_pos: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.detach_query: + x = x.detach() + x = self.norm_attnx(x) + context = self.norm_attnctx(context) + k, v = rearrange( + self.kv(context), 'b n (kv h d) -> b h n d kv', h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q(x), 'b n (h d) -> b h n d', h=self.num_heads) + + if rope is not None: + q = rope(q.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3) + k = rope(k.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, 'b n (h d) -> b h n d', h=self.num_heads + ) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, 'b n (h d) -> b h n d', h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, 'b h n d -> b n (h d)') + x = self.out(x) + return x + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + attn_bias: torch.Tensor | None = None, + rope: nn.Module | None = None, + rope_pos: torch.Tensor | None = None, + ) -> torch.Tensor: + context = x if context is None else context + x = self.ls1_1( + self.attn( + x, + rope=rope, + rope_pos=rope_pos, + attn_bias=attn_bias, + context=context, + pos_embed=pos_embed, + pos_embed_context=pos_embed_context, + ) + ) + self.ls1_2(x) + x = self.ls2(self.mlp(x)) + x + return x + + +class AttentionSeq(nn.Module): + def __init__( + self, + num_blocks: int, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + detach_query: bool = False, + residual_ls: bool = False, + ): + super().__init__() + self.layers = nn.ModuleList( + [ + AttentionBlock( + dim=dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + cosine=cosine, + gated=gated, + layer_scale=layer_scale, + context_dim=context_dim, + detach_query=detach_query, + residual_ls=residual_ls, + ) + for _ in range(num_blocks) + ] + ) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + attn_bias: torch.Tensor | None = None, + rope: nn.Module | None = None, + rope_pos: torch.Tensor | None = None, + ) -> torch.Tensor: + for layer in self.layers: + x = layer( + x, + context=context, + pos_embed=pos_embed, + pos_embed_context=pos_embed_context, + attn_bias=attn_bias, + rope=rope, + rope_pos=rope_pos, + ) + return x + + +class ResidualConvNet(nn.Module): + def __init__( + self, + dim, + kernel_size: int = 3, + padding_mode: str = 'zeros', + dilation: int = 1, + layer_scale: float = 1.0, + use_norm: bool = False, + ): + super().__init__() + self.conv1 = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=dilation * (kernel_size - 1) // 2, + dilation=dilation, + padding_mode=padding_mode, + ) + self.conv2 = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=dilation * (kernel_size - 1) // 2, + dilation=dilation, + padding_mode=padding_mode, + ) + self.activation = nn.LeakyReLU() + 
self.gamma = ( + nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1)) + if layer_scale > 0.0 + else 1.0 + ) + self.norm1 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity() + self.norm2 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity() + + def forward(self, x): + out = self.activation(x) + out = self.conv1(out) + out = self.norm1(out) + out = self.activation(out) + out = self.conv2(out) + out = self.norm2(out) + return self.gamma * out + x + + +class ResidualUpsampler(nn.Module): + def __init__( + self, + hidden_dim, + output_dim: int = None, + num_layers: int = 2, + kernel_size: int = 3, + layer_scale: float = 1.0, + padding_mode: str = 'zeros', + use_norm: bool = False, + **kwargs, + ): + super().__init__() + output_dim = output_dim if output_dim is not None else hidden_dim // 2 + self.convs = nn.ModuleList([]) + for _ in range(num_layers): + self.convs.append( + ResidualConvNet( + hidden_dim, + kernel_size=kernel_size, + layer_scale=layer_scale, + padding_mode=padding_mode, + use_norm=use_norm, + ) + ) + self.up = nn.Sequential( + nn.Conv2d( + hidden_dim, + output_dim, + kernel_size=1, + padding=0, + padding_mode=padding_mode, + ), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + ) + + def forward(self, x: torch.Tensor): + for conv in self.convs: + x = conv(x) + x = self.up(x) + return x diff --git a/src/da2/model/dinov2/__init__.py b/src/da2/model/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89f2115deeced559d02bd871f98ec86e9ec6d97c --- /dev/null +++ b/src/da2/model/dinov2/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .dinovit import ( + DINOViT +) + +__all__ = [ + 'DINOViT' +] diff --git a/src/da2/model/dinov2/__pycache__/__init__.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a3b0109aacc0704664188faedff5938ebca3493 Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/da2/model/dinov2/__pycache__/attention.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f27213c6558f754ed1691024894be5b0f2a0377f Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/attention.cpython-312.pyc differ diff --git a/src/da2/model/dinov2/__pycache__/block.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/block.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..834d0c242bf193fbf4bb0b091c3b51518c67894d Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/block.cpython-312.pyc differ diff --git a/src/da2/model/dinov2/__pycache__/dinovit.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/dinovit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..843708a72c84d53570064acaf2b195930f6c5500 Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/dinovit.cpython-312.pyc differ diff --git a/src/da2/model/dinov2/__pycache__/drop_path.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/drop_path.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c184bb73a9b53af7d9db2636b939be7df22857c Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/drop_path.cpython-312.pyc differ diff --git a/src/da2/model/dinov2/__pycache__/layer_scale.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/layer_scale.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f949b8580fe05d570db78412686b7ff47df7501a Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/layer_scale.cpython-312.pyc differ diff --git a/src/da2/model/dinov2/__pycache__/mlp.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/mlp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e40153264bdb86c6fa52a8c68ac87c90075b69da Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/mlp.cpython-312.pyc differ diff --git a/src/da2/model/dinov2/__pycache__/patch_embed.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/patch_embed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bf7602032d36a56eb10300939c095deed6c7321 Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/patch_embed.cpython-312.pyc differ diff --git a/src/da2/model/dinov2/__pycache__/swiglu_ffn.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/swiglu_ffn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6494419f8663a907cc9b66f797dbe19510bde169 Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/swiglu_ffn.cpython-312.pyc differ diff --git a/src/da2/model/dinov2/attention.py b/src/da2/model/dinov2/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8348acc8d18745a2275f6d76924c8a4e03ce324e --- /dev/null +++ b/src/da2/model/dinov2/attention.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha, memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0.0 else nn.Identity() + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + x = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2]) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + if not XFORMERS_AVAILABLE or x.device.type == "cpu": + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/src/da2/model/dinov2/block.py b/src/da2/model/dinov2/block.py new file mode 100644 index 0000000000000000000000000000000000000000..e50619cc568d0a3c91e5630d9bec4a26212150c0 --- /dev/null +++ b/src/da2/model/dinov2/block.py @@ -0,0 +1,280 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Any, Callable, Dict, List, Tuple + +import torch +import torch.nn as nn + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + +logger = logging.getLogger("dinov2") + +try: + from xformers.ops import fmha, index_select_cat, scaled_index_add + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: torch.Tensor) -> torch.Tensor: + def attn_residual_func(x: torch.Tensor) -> torch.Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: torch.Tensor) -> torch.Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: torch.Tensor, + residual_func, #: Callable[[torch.Tensor], torch.Tensor], + sample_drop_ratio: float = 0.0, +) -> torch.Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, 
residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + x_plus_residual = scaled_index_add( + x, + brange, + residual.to(dtype=x.dtype), + scaling=scaling_vector, + alpha=residual_scale_factor, + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[torch.Tensor], + residual_func, #: Callable[[torch.Tensor, Any], torch.Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> torch.Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + 
sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=( + self.ls1.gamma if isinstance(self.ls1, LayerScale) else None + ), + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=( + self.ls2.gamma if isinstance(self.ls1, LayerScale) else None + ), + ) + return x_list + else: + + def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, torch.Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert ( + XFORMERS_AVAILABLE + ), "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/src/da2/model/dinov2/dino_head.py b/src/da2/model/dinov2/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1147dd3a3c046aee8d427b42b1055f38a218275b --- /dev/null +++ b/src/da2/model/dinov2/dino_head.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp( + nlayers, + in_dim, + bottleneck_dim, + hidden_dim=hidden_dim, + use_bn=use_bn, + bias=mlp_bias, + ) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp( + nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True +): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/src/da2/model/dinov2/dinovit.py b/src/da2/model/dinov2/dinovit.py new file mode 100644 index 0000000000000000000000000000000000000000..04e679d67ec8024c23f788d225656b621165825b --- /dev/null +++ b/src/da2/model/dinov2/dinovit.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +import math +import torch +import torch.nn as nn +import contextlib +from functools import partial +from typing import Sequence +from .block import ( + Block +) +from .attention import ( + MemEffAttention +) +from .mlp import ( + Mlp +) +from .patch_embed import ( + PatchEmbed +) +from .swiglu_ffn import ( + SwiGLUFFNFused +) + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha, memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class DINOViT(nn.Module): + def __init__( + self, + img_size=518, + patch_size=14, + in_chans=3, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=1.0, + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=partial(Block, attn_class=MemEffAttention), + ffn_layer="mlp", + block_chunks=0, + output_idx=[6, 12, 18, 24], + num_register_tokens=0, + interpolate_antialias=False, + use_norm=True, + frozen_stages=0, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + self.frozen_stages = frozen_stages + self.patch_size = patch_size + self.output_idx = output_idx + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim) + ) + assert num_register_tokens >= 0 + self.register_tokens = nn.Parameter( + torch.zeros(1, max(1, num_register_tokens), embed_dim) + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] + + if ffn_layer == "mlp": + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + def f(): + return nn.Identity() + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + 
num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=nn.LayerNorm, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = nn.LayerNorm(embed_dim) + self.use_norm = use_norm + self.head = nn.Identity() + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + def interpolate_pos_encoding(self, x, W, H): + previous_dtype = x.dtype + N = self.pos_embed.shape[1] - 1 + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = W // self.patch_size + h0 = H // self.patch_size + + M = int(math.sqrt(N)) + assert N == M * M + kwargs = {} + kwargs["size"] = (w0, h0) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def tokenize(self, x): + _, _, W, H = x.shape + with torch.no_grad() if self.frozen_stages > -1 else contextlib.nullcontext(): + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + dino_pos_embed = self.interpolate_pos_encoding(x, W, H) + x = x + dino_pos_embed + return x + + def forward_features(self, x): + shapes = [val // self.patch_size for val in x.shape[-2:]] + batch_size = x.shape[0] + features = [] + x = self.tokenize(x) + for i, blk in enumerate(self.blocks): + with ( + torch.no_grad() if i < self.frozen_stages else contextlib.nullcontext() + ): + x = blk(x) + features.append(x) + if self.use_norm: + with ( + torch.no_grad() + if self.frozen_stages >= len(self.blocks) + else contextlib.nullcontext() + ): + features = [self.norm(out) for out in features] + features = [out[:, self.num_register_tokens + 1 :] for out in features] + features = [out.reshape(batch_size, *shapes, -1) for out in features] + return features + + def forward(self, *args): + features = self.forward_features(*args) + return features diff --git a/src/da2/model/dinov2/drop_path.py b/src/da2/model/dinov2/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..35b1a620d06ba862ea05297d271d8c2c625b5f93 --- /dev/null +++ b/src/da2/model/dinov2/drop_path.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +import torch.nn as nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/src/da2/model/dinov2/layer_scale.py b/src/da2/model/dinov2/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..40d18b5427183534d5516652b076f9883a609fc6 --- /dev/null +++ b/src/da2/model/dinov2/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +import torch.nn as nn +from torch import Tensor + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/src/da2/model/dinov2/mlp.py b/src/da2/model/dinov2/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..af598999855d897948142cc986fce82abc9e3b53 --- /dev/null +++ b/src/da2/model/dinov2/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
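+# Plain two-layer transformer MLP: fc1 expands in_features to hidden_features (typically
+# mlp_ratio * embed_dim in the blocks above), the activation (GELU by default) is applied,
+# and fc2 projects back to out_features; the dropout layers are only real nn.Dropout
+# modules when drop > 0.0, otherwise nn.Identity.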
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/src/da2/model/dinov2/patch_embed.py b/src/da2/model/dinov2/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a56c02609e67922eb8f859588ef274e5298b55 --- /dev/null +++ b/src/da2/model/dinov2/patch_embed.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +import torch.nn as nn +from torch import Tensor + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
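+        flatten_embedding: If True (default), tokens are returned as (B, N, D); otherwise the
+            (B, H', W', D) spatial grid is kept.
+
+    With the values DINOViT uses in this repo (img_size=518, patch_size=14), the patch grid is
+    37 x 37, i.e. num_patches = 1369.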
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert ( + H % patch_H == 0 + ), f"Input image height {H} is not a multiple of patch height {patch_H}" + assert ( + W % patch_W == 0 + ), f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/src/da2/model/dinov2/swiglu_ffn.py b/src/da2/model/dinov2/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..e82999e9b09b41cd6aba9edbc4c05d51ab663a1e --- /dev/null +++ b/src/da2/model/dinov2/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable, Optional + +import torch.nn.functional as F +from torch import Tensor, nn + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/src/da2/model/sphere.py b/src/da2/model/sphere.py new file mode 100644 index 0000000000000000000000000000000000000000..a969bdfffe707052793b9c399f9fdc6f54b96102 --- /dev/null +++ b/src/da2/model/sphere.py @@ -0,0 +1,30 @@ +import torch + + +def get_uv_gird(h, w, device): + pixel_coords_x = torch.linspace(0.5, w - 0.5, w, device=device) + pixel_coords_y = torch.linspace(0.5, h - 0.5, h, device=device) + stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()] + grid = torch.stack(stacks, dim=0).float() + grid = grid.to(device).unsqueeze(0) + return grid + + +class Sphere(): + def __init__(self, config, device): + self.config = config + self.device = device + + def get_directions(self, shape): + h, w = shape + uv = get_uv_gird(h, w, device=self.device) + u, v = uv.unbind(dim=1) + width, height = self.config['width'], self.config['height'] + hfov, vfov = self.config['hfov'], self.config['vfov'] + longitude = (u - width / 2) / width * hfov + latitude = (v - height / 2) / height * vfov + x = torch.cos(latitude) * torch.sin(longitude) + z = torch.cos(latitude) * torch.cos(longitude) + y = torch.sin(latitude) + sphere_directions = torch.stack([x, y, z], dim=1) + return sphere_directions diff --git a/src/da2/model/spherevit.py b/src/da2/model/spherevit.py new file mode 100644 index 0000000000000000000000000000000000000000..6927e4944eade89d108f50f4bb14f1ef6ed5aae1 --- /dev/null +++ b/src/da2/model/spherevit.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +from math import ( + ceil, + sqrt +) +from huggingface_hub import PyTorchModelHubMixin +import torchvision.transforms.v2.functional as TF +from .dinov2 import DINOViT +from .vit_w_esphere import ViT_w_Esphere +from .sphere import Sphere + + +IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DATASET_STD = (0.229, 0.224, 0.225) + +class SphereViT(nn.Module, PyTorchModelHubMixin): + def __init__(self, config): + super().__init__() + self.config = config + self.dino = DINOViT() + self.vit_w_esphere = 
ViT_w_Esphere(config['spherevit']['vit_w_esphere']) + feature_slices = self.dino.output_idx + self.feature_slices = list( + zip([0, *feature_slices[:-1]], feature_slices) + ) + self.device = None + + def to(self, *args): + self.device = args[0] + return super().to(*args) + + def forward(self, images): + B, _, H, W = images.shape + current_pixels = H * W + target_pixels = min(self.config['inference']['max_pixels'], + max(self.config['inference']['min_pixels'], current_pixels)) + factor = sqrt(target_pixels / current_pixels) + sphere_config = deepcopy(self.config['spherevit']['sphere']) + sphere_config['width'] *= factor + sphere_config['height'] *= factor + sphere = Sphere(config=sphere_config, device=self.device) + H_new = int(H * factor) + W_new = int(W * factor) + DINO_patch_size = 14 # please see the line 51 of `src/da2/model/dinov2/dinovit.py` (I know it's a little ugly to hardcode it here T_T) + H_new = ceil(H_new / DINO_patch_size) * DINO_patch_size + W_new = ceil(W_new / DINO_patch_size) * DINO_patch_size + images = F.interpolate(images, size=(H_new, W_new), mode='bilinear', align_corners=False) + images = TF.normalize( + images.float(), + mean=IMAGENET_DATASET_MEAN, + std=IMAGENET_DATASET_STD, + ) + + sphere_dirs = sphere.get_directions(shape=(H_new, W_new)) + sphere_dirs = sphere_dirs.to(self.device) + sphere_dirs = sphere_dirs.repeat(B, 1, 1, 1) + + features = self.dino(images) + features = [ + features[i:j][-1].contiguous() + for i, j in self.feature_slices + ] + distance = self.vit_w_esphere(images, features, sphere_dirs) + distance = F.interpolate(distance, size=(H, W), mode='bilinear', align_corners=False) + distance = distance.squeeze(dim=1) # (b, 1, h, w) -> (b, h, w) + return distance diff --git a/src/da2/model/vit_w_esphere.py b/src/da2/model/vit_w_esphere.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6e34f2aeee0c6fba48c8716c4218c50b3d0135 --- /dev/null +++ b/src/da2/model/vit_w_esphere.py @@ -0,0 +1,224 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from .base import ( + fourier_dimension_expansion, + flatten, + DimensionAligner, + AttentionSeq, + ResidualUpsampler +) + + +class _ViT_w_Esphere(nn.Module): + def __init__( + self, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + num_layers_head: int | list[int] = 4, + dropout: float = 0.0, + kernel_size: int = 7, + layer_scale: float = 1.0, + out_dim: int = 1, + num_prompt_blocks: int = 1, + use_norm: bool = False, + **kwargs, + ) -> None: + super().__init__() + self.out_dim = out_dim + self.hidden_dim = hidden_dim + self.up_sampler = nn.ModuleList([]) + self.pred_head = nn.ModuleList([]) + self.process_features = nn.ModuleList([]) + self.prompt_camera = nn.ModuleList([]) + mult = 2 + self.to_latents = nn.Linear(hidden_dim, hidden_dim) + + for _ in range(4): + self.prompt_camera.append( + AttentionSeq( + num_blocks=num_prompt_blocks, + dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + layer_scale=-1.0, + context_dim=hidden_dim, + ) + ) + + for i, depth in enumerate(num_layers_head): + current_dim = min(hidden_dim, mult * hidden_dim // int(2**i)) + next_dim = mult * hidden_dim // int(2 ** (i + 1)) + output_dim = max(next_dim, out_dim) + self.process_features.append( + nn.ConvTranspose2d( + hidden_dim, + current_dim, + kernel_size=max(1, 2 * i), + stride=max(1, 2 * i), + padding=0, + ) + ) + self.up_sampler.append( + ResidualUpsampler( + current_dim, + output_dim=output_dim, + 
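+                    # Decoder pyramid, stage i: channels go from
+                    # current_dim = min(hidden_dim, 2 * hidden_dim // 2**i) down to
+                    # max(next_dim, out_dim) with next_dim = 2 * hidden_dim // 2**(i + 1);
+                    # process_features[i] is a ConvTranspose2d with kernel = stride = max(1, 2 * i),
+                    # whose output is added to the running latents in process() before this
+                    # ResidualUpsampler (defined in .base, not shown in this diff) refines them.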
expansion=expansion, + layer_scale=layer_scale, + kernel_size=kernel_size, + num_layers=depth, + use_norm=use_norm, + ) + ) + pred_head = ( + nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim)) + if i == len(num_layers_head) - 1 + else nn.Identity() + ) + self.pred_head.append(pred_head) + + self.to_depth_lr = nn.Conv2d( + output_dim, + output_dim // 2, + kernel_size=3, + padding=1, + padding_mode='reflect', + ) + self.to_confidence_lr = nn.Conv2d( + output_dim, + output_dim // 2, + kernel_size=3, + padding=1, + padding_mode='reflect', + ) + self.to_depth_hr = nn.Sequential( + nn.Conv2d( + output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect' + ), + nn.LeakyReLU(), + nn.Conv2d(32, 1, kernel_size=1), + ) + self.to_confidence_hr = nn.Sequential( + nn.Conv2d( + output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect' + ), + nn.LeakyReLU(), + nn.Conv2d(32, 1, kernel_size=1), + ) + + def set_original_shapes(self, shapes: tuple[int, int]): + self.original_shapes = shapes + + def set_shapes(self, shapes: tuple[int, int]): + self.shapes = shapes + + def embed_sphere_dirs(self, sphere_dirs): + sphere_embedding = flatten( + sphere_dirs, old=self.original_shapes, new=self.shapes + ) + # index 0 -> Y + # index 1 -> Z + # index 2 -> X + r1, r2, r3 = sphere_embedding[..., 0], sphere_embedding[..., 1], sphere_embedding[..., 2] + polar = torch.asin(r2) + r3_clipped = r3.abs().clip(min=1e-5) * (2 * (r3 >= 0).int() - 1) + azimuth = torch.atan2(r1, r3_clipped) + # [polar, azimuth] is the angle field + sphere_embedding = torch.stack([polar, azimuth], dim=-1) + # expand the dimension of the angle field to image feature dimensions, via sine-cosine basis embedding + sphere_embedding = fourier_dimension_expansion( + sphere_embedding, + dim=self.hidden_dim, + max_freq=max(self.shapes) // 2, + use_cos=False, + ) + return sphere_embedding + + def condition(self, feat, sphere_embeddings): + conditioned_features = [ + prompter(rearrange(feature, 'b h w c -> b (h w) c'), sphere_embeddings) + for prompter, feature in zip(self.prompt_camera, feat) + ] + return conditioned_features + + def process(self, features_list, sphere_embeddings): + conditioned_features = self.condition(features_list, sphere_embeddings) + init_latents = self.to_latents(conditioned_features[0]) + init_latents = rearrange( + init_latents, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1] + ).contiguous() + conditioned_features = [ + rearrange( + x, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1] + ).contiguous() + for x in conditioned_features + ] + latents = init_latents + + out_features = [] + # Pyramid-like multi-layer convolutional feature extraction + for i, up in enumerate(self.up_sampler): + latents = latents + self.process_features[i](conditioned_features[i + 1]) + latents = up(latents) + out_features.append(latents) + return out_features + + def prediction_head(self, out_features): + depths = [] + h_out, w_out = out_features[-1].shape[-2:] + for i, (layer, features) in enumerate(zip(self.pred_head, out_features)): + out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + if i < len(self.pred_head) - 1: + continue + depths.append(out_depth_features) + out_depth_features = F.interpolate( + out_depth_features, size=(h_out, w_out), mode='bilinear', align_corners=True + ) + distance = self.to_depth_lr(out_depth_features) + distance = F.interpolate( + distance, size=self.original_shapes, mode='bilinear', align_corners=True + ) + distance = 
self.to_depth_hr(distance) + return distance + + def forward( + self, + features: list[torch.Tensor], + sphere_dirs: torch.Tensor + ) -> torch.Tensor: + sphere_embeddings = self.embed_sphere_dirs(sphere_dirs) + features = self.process(features, sphere_embeddings) + distance = self.prediction_head(features) + return distance + + +class ViT_w_Esphere(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.dim_aligner = DimensionAligner( + input_dims=config['input_dims'], + hidden_dim=config['hidden_dim'], + ) + self._vit_w_esphere = _ViT_w_Esphere(**config) + + def forward(self, images, features, sphere_dirs) -> torch.Tensor: + _, _, H, W = images.shape + sphere_dirs = sphere_dirs + common_shape = features[0].shape[1:3] + features = self.dim_aligner(features) + sphere_dirs = rearrange(sphere_dirs, 'b c h w -> b (h w) c') + + self._vit_w_esphere.set_shapes(common_shape) + self._vit_w_esphere.set_original_shapes((H, W)) + logdistance = self._vit_w_esphere( + features=features, + sphere_dirs=sphere_dirs, + ) + + distance = torch.exp(logdistance.clip(min=-8.0, max=8.0) + 2.0) + distance = distance / torch.quantile(distance, 0.98) + return distance diff --git a/src/da2/utils/__init__.py b/src/da2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac61fb325760b9f30f6444fc386d07b9c23022f5 --- /dev/null +++ b/src/da2/utils/__init__.py @@ -0,0 +1,11 @@ +from .base import ( + prepare_to_run +) +from .model import ( + load_model +) + +__all__ = [ + 'prepare_to_run', + 'load_model' +] diff --git a/src/da2/utils/__pycache__/__init__.cpython-312.pyc b/src/da2/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c636b81d1f058165fb6316e7109a4f77b696e7b9 Binary files /dev/null and b/src/da2/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/da2/utils/__pycache__/base.cpython-312.pyc b/src/da2/utils/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c74a8a9801068a53cc8d00100da6c1ef698cb381 Binary files /dev/null and b/src/da2/utils/__pycache__/base.cpython-312.pyc differ diff --git a/src/da2/utils/__pycache__/d2pc.cpython-312.pyc b/src/da2/utils/__pycache__/d2pc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70fd92a61e91bd241a3230635c36ff95f86b1adf Binary files /dev/null and b/src/da2/utils/__pycache__/d2pc.cpython-312.pyc differ diff --git a/src/da2/utils/__pycache__/io.cpython-312.pyc b/src/da2/utils/__pycache__/io.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45ccb3ac4d2a23f52c6dd4fb97c2a13b4a4b259e Binary files /dev/null and b/src/da2/utils/__pycache__/io.cpython-312.pyc differ diff --git a/src/da2/utils/__pycache__/model.cpython-312.pyc b/src/da2/utils/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47e706b0ec91363b1b3ea0f68f0065318f36ffad Binary files /dev/null and b/src/da2/utils/__pycache__/model.cpython-312.pyc differ diff --git a/src/da2/utils/__pycache__/vis.cpython-312.pyc b/src/da2/utils/__pycache__/vis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bebaecbb4e809ca9277277597b264c78623a2d35 Binary files /dev/null and b/src/da2/utils/__pycache__/vis.cpython-312.pyc differ diff --git a/src/da2/utils/base.py b/src/da2/utils/base.py new file mode 100644 index 
0000000000000000000000000000000000000000..148b261d5646947592c6cb00ed96eb8955c5bea6 --- /dev/null +++ b/src/da2/utils/base.py @@ -0,0 +1,56 @@ +import json +import argparse +import os +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ( + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed +) +import logging +from datetime import ( + timedelta, + datetime +) + + +def load_config(config_path): + with open(config_path, 'r') as f: + config = json.load(f) + return config + +def arg_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('--config_path', type=str, required=True) + args = parser.parse_args() + return args + +def prepare_to_run(): + args = arg_parser() + logging.basicConfig( + format='%(asctime)s --> %(message)s', + datefmt='%m/%d %H:%M:%S', + level=logging.INFO, + ) + config = load_config(args.config_path) + kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=config['accelerator']['timeout'])) + version = os.path.basename(args.config_path)[:-5] + output_dir = f'output/{version}_{datetime.now().strftime("%Y%m%d_%H%M%S")}' + if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) + accu_steps = config['accelerator']['accumulation_nsteps'] + accelerator = Accelerator( + gradient_accumulation_steps=accu_steps, + mixed_precision=config['accelerator']['mixed_precision'], + log_with=config['accelerator']['report_to'], + project_config=ProjectConfiguration(project_dir=output_dir), + kwargs_handlers=[kwargs] + ) + logger = get_logger(__name__, log_level='INFO') + config['env']['logger'] = logger + set_seed(config['env']['seed']) + if config['env']['verbose']: + logger.info(f'Version: {version} (from {args.config_path})') + logger.info(f'Output dir: {output_dir}') + logger.info(f'Using {accelerator.num_processes} GPU' + ('s' if accelerator.num_processes > 1 else '')) + return config, accelerator, output_dir diff --git a/src/da2/utils/d2pc.py b/src/da2/utils/d2pc.py new file mode 100644 index 0000000000000000000000000000000000000000..e67c8f4c4387cfde4aac4a791428e4ee27df3742 --- /dev/null +++ b/src/da2/utils/d2pc.py @@ -0,0 +1,76 @@ +import os +import numpy as np +import utils3d +from plyfile import PlyData, PlyElement +from PIL import Image + + +def sphere_uv2dirs(uv: np.ndarray): + theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi + directions = np.stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1) + return directions + +def save_3d_points(points: np.array, colors: np.array, mask: np.array, save_path: str): + points = points.reshape(-1, 3) + colors = colors.reshape(-1, 3) + mask = mask.reshape(-1, 1) + + vertex_data = np.empty(mask.sum(), dtype=[ + ('x', 'f4'), + ('y', 'f4'), + ('z', 'f4'), + ('red', 'u1'), + ('green', 'u1'), + ('blue', 'u1'), + ]) + vertex_data['x'] = [a for i, a in enumerate(points[:, 0]) if mask[i]] + vertex_data['y'] = [a for i, a in enumerate(points[:, 1]) if mask[i]] + vertex_data['z'] = [a for i, a in enumerate(points[:, 2]) if mask[i]] + vertex_data['red'] = [a for i, a in enumerate(colors[:, 0]) if mask[i]] + vertex_data['green'] = [a for i, a in enumerate(colors[:, 1]) if mask[i]] + vertex_data['blue'] = [a for i, a in enumerate(colors[:, 2]) if mask[i]] + + vertex_element = PlyElement.describe(vertex_data, 'vertex', comments=['vertices with color']) + save_dir = os.path.dirname(save_path) + if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True) + PlyData([vertex_element], 
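+    # text=True writes an ASCII (human-readable) PLY; only vertices with mask == True were kept
+    # when vertex_data was filled above.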
text=True).write(save_path) + +def colorize_normal(normal: np.ndarray, normal_mask: np.ndarray): + normal_rgb = (((normal + 1) * 0.5) * 255).astype(np.uint8) + normal_mask = np.repeat(np.expand_dims(normal_mask, axis=-1), 3, axis=-1) + normal_mask = normal_mask.astype(np.uint8) + normal_rgb = normal_rgb * normal_mask + return normal_rgb + +def normal_normalize(normal: np.ndarray): + normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True) + normal_norm[normal_norm < 1e-6] = 1e-6 + return normal / normal_norm + +def distance2pointcloud( + distance: np.ndarray, + image: np.ndarray, + mask: np.ndarray, + save_path: str = None, + return_normal: bool = False, + save_distance: bool = False +): + if distance.ndim >= 3: distance = distance.squeeze() + if save_distance: + save_path_dis = save_path.replace('3dpc', 'depth').replace('.ply', '.npy') + save_dir_dis = os.path.dirname(save_path_dis) + if not os.path.exists(save_dir_dis): os.makedirs(save_dir_dis, exist_ok=True) + np.save(save_path_dis, distance) + height, width = distance.shape[:2] + points = distance[:, :, None] * sphere_uv2dirs(utils3d.numpy.image_uv(width=width, height=height)) + save_3d_points(points, image, mask, save_path) + if return_normal: + normal, normal_mask = utils3d.numpy.points_to_normals(points, mask) + normal = normal * np.array([-1, -1, 1]) + normal = normal_normalize(normal) + normal_1 = normal[..., 0] + normal_2 = normal[..., 1] + normal_3 = normal[..., 2] + normal = np.stack([normal_1, normal_3, normal_2], axis=-1) + normal_img = colorize_normal(normal, normal_mask) + return Image.fromarray(normal_img) diff --git a/src/da2/utils/io.py b/src/da2/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..ccaf2dc0034db8ae04d02f6c6a059b98827d1d4c --- /dev/null +++ b/src/da2/utils/io.py @@ -0,0 +1,63 @@ +import os +import cv2 +import torch +import numpy as np +from glob import glob +from PIL import Image + + +def torch_transform(image): + image = image / 255.0 + image = np.transpose(image, (2, 0, 1)) + return image + +def read_cv2_image(image_path): + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image + +def read_mask(mask_path, shape): + if not os.path.exists(mask_path): + return np.ones(shape[1:]) > 0 + mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) + mask = mask > 0 + return mask + +def tensorize(array, model_dtype, device): + array = torch.from_numpy(array).to(device).to(model_dtype).unsqueeze(dim=0) + return array + +def load_infer_data(config, device): + image_dir = config['inference']['images'] + mask_dir = config['inference']['masks'] + + image_paths = glob(os.path.join(image_dir, '*.png')) + image_paths = sorted(image_paths) + filenames = [os.path.basename(image_path)[:-4] for image_path in image_paths] + cv2_images = [read_cv2_image(image_path) + for image_path in image_paths] + PIL_images = [Image.fromarray(cv2_image) for cv2_image in cv2_images] + images = [torch_transform(cv2_image) for cv2_image in cv2_images] + + mask_paths = [image_path.replace(image_dir, mask_dir) + for image_path in image_paths] + masks = [read_mask(mask_path, images[i].shape) + for (i, mask_path) in enumerate(mask_paths)] + + model_dtype = config['spherevit']['dtype'] + images = [tensorize(image, model_dtype, device) for image in images] + + infer_data = { + 'images': { + 'PIL': PIL_images, + 'cv2': cv2_images, + 'torch': images + }, + 'masks': masks, + 'filenames': filenames, + 'size': len(images) + } + if config['env']['verbose']: + s = 's' if len(images) > 1 
else '' + config['env']['logger'].info(f'Loaded {len(images)} image{s} in {model_dtype}') + return infer_data diff --git a/src/da2/utils/model.py b/src/da2/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0088ab1e0a1dd995767ea8c107a40b6fb1c2e836 --- /dev/null +++ b/src/da2/utils/model.py @@ -0,0 +1,15 @@ +import torch +from ..model.spherevit import SphereViT + + +def load_model(config, accelerator): + model = SphereViT.from_pretrained('haodongli/DA-2', config=config) + model = model.to(accelerator.device) + torch.cuda.empty_cache() + model = accelerator.prepare(model) + if accelerator.num_processes > 1: + model = model.module + if config['env']['verbose']: + config['env']['logger'].info(f'Model\'s dtype: {next(model.parameters()).dtype}.') + config['spherevit']['dtype'] = next(model.parameters()).dtype + return model diff --git a/src/da2/utils/vis.py b/src/da2/utils/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..6cbeab57788963cd0b7ff678979f8723d80911b9 --- /dev/null +++ b/src/da2/utils/vis.py @@ -0,0 +1,44 @@ +import torch +from PIL import Image +import numpy as np +import matplotlib +import cv2 + + +def concatenate_images(*image_lists): + max_width = 0 + total_height = 0 + row_widths = [] + row_heights = [] + + for i, image_list in enumerate(image_lists): + width = sum(img.width for img in image_list) + max_width = max(max_width, width) + row_widths.append(width) + # Assuming all images in the list have the same height + height = image_list[0].height + total_height += height + row_heights.append(height) + + new_image = Image.new('RGB', (max_width, total_height)) + y_offset = 0 + for i, image_list in enumerate(image_lists): + x_offset = 0 + for img in image_list: + new_image.paste(img, (x_offset, y_offset)) + x_offset += img.width + y_offset += row_heights[i] + return new_image + +def colorize_distance(distance, mask, cmap='Spectral'): + if distance.ndim >= 3: distance = distance.squeeze() + cm = matplotlib.colormaps[cmap] + valid_distance = distance[mask] + max_distance = np.quantile(valid_distance, 0.98) + min_distance = np.quantile(valid_distance, 0.02) + distance[~mask] = max_distance + distance = ((distance - min_distance) / (max_distance - min_distance)) + distance = np.clip(distance, 0, 1) + img_colored_np = cm(distance, bytes=False)[:, :, 0:3] + distance_colored = (img_colored_np * 255).astype(np.uint8) + return Image.fromarray(distance_colored) diff --git a/src/pyproject.toml b/src/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..933ef0b55c7b60e158160965754924062a83eacd --- /dev/null +++ b/src/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "da2" +version = "0.1.0" +description = "For the implementation of DA^2: Depth Anything in Any Direction" +authors = [ + { name="H. 
Li", email="hal211@ucsd.edu" } +] +dependencies = [ + "torch==2.5.0", + "torchvision==0.20.0", + "torchaudio==2.5.0", + "xformers==0.0.28.post2", + "diffusers==0.32.0", + "tensorboard==2.18.0", + "utils3d @ git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900", + "opencv-python==4.12.0.88", + "gradio==5.49.0", + "gradio-client==1.13.3", + "gradio-imageslider==0.0.20", + "accelerate==1.1.1", + "omegaconf==2.3.0", + "tabulate==0.9.0", + "einops==0.8.0", + "timm==1.0.15", + "trimesh==4.5.2", + "transformers==4.46.3", + "matplotlib==3.9.2" +] + +[tool.setuptools.packages.find] +where = ["."] \ No newline at end of file