diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..19f846cfea6e8d0869d82423a206db6e52aaa6ff 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..831669619f694d7dc1d9a82ed9ae0c235decb783
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+cache/
+output/
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2952c5879bd150fda1f431e62f125ea35464a4d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,153 @@
+import os
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import (
+ InitProcessGroupKwargs,
+ ProjectConfiguration,
+ set_seed
+)
+import torch
+from contextlib import nullcontext
+import trimesh
+import gradio as gr
+from gradio_imageslider import ImageSlider
+from da2.utils.base import load_config
+from da2.utils.model import load_model
+from da2.utils.io import (
+ read_cv2_image,
+ torch_transform,
+ tensorize
+)
+from da2.utils.vis import colorize_distance
+from da2.utils.d2pc import distance2pointcloud
+from datetime import (
+ timedelta,
+ datetime
+)
+import cv2
+import numpy as np
+
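+# Path of the most recently exported GLB; it is deleted on the next request so stale files do not pile up in cache/.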
+last_glb_path = None
+
+def prepare_to_run_demo():
+ config = load_config('configs/infer.json')
+ kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=config['accelerator']['timeout']))
+    output_dir = 'output/infer'
+    os.makedirs(output_dir, exist_ok=True)
+ accu_steps = config['accelerator']['accumulation_nsteps']
+ accelerator = Accelerator(
+ gradient_accumulation_steps=accu_steps,
+ mixed_precision=config['accelerator']['mixed_precision'],
+ log_with=config['accelerator']['report_to'],
+ project_config=ProjectConfiguration(project_dir=output_dir),
+ kwargs_handlers=[kwargs]
+ )
+ logger = get_logger(__name__, log_level='INFO')
+ config['env']['logger'] = logger
+ set_seed(config['env']['seed'])
+ return config, accelerator
+
+def read_mask_demo(mask_path, shape):
+ if mask_path is None:
+ return np.ones(shape[1:]) > 0
+ mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+ mask = mask > 0
+ return mask
+
+def load_infer_data_demo(image, mask, model_dtype, device):
+ cv2_image = read_cv2_image(image)
+ image = torch_transform(cv2_image)
+ mask = read_mask_demo(mask, image.shape)
+ image = tensorize(image, model_dtype, device)
+ return image, cv2_image, mask
+
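+# Convert the saved PLY point cloud to GLB for Gradio's Model3D viewer and delete the intermediate PLY.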
+def ply2glb(ply_path, glb_path):
+ pcd = trimesh.load(ply_path)
+ points = np.asarray(pcd.vertices)
+ colors = np.asarray(pcd.visual.vertex_colors)
+ cloud = trimesh.points.PointCloud(vertices=points, colors=colors)
+ cloud.export(glb_path)
+ os.remove(ply_path)
+
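+# Gradio callback: runs DA2 on one (optionally masked) image and returns a GLB point cloud plus distance/normal visualizations.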
+def fn(image_path, mask_path):
+ global last_glb_path
+ config, accelerator = prepare_to_run_demo()
+ model = load_model(config, accelerator)
+ image, cv2_image, mask = load_infer_data_demo(image_path, mask_path,
+ model_dtype=config['spherevit']['dtype'], device=accelerator.device)
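+    # Autocast support on Apple MPS is limited, so fall back to running without it there.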
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+ with autocast_ctx, torch.no_grad():
+ distance = model(image).cpu().numpy()[0]
+ if last_glb_path is not None:
+ os.remove(last_glb_path)
+ distance_vis = colorize_distance(distance, mask)
+    os.makedirs('cache', exist_ok=True)  # cache/ is gitignored, so make sure it exists before writing
+    save_path = f'cache/tmp_{datetime.now().strftime("%Y%m%d_%H%M%S")}.glb'
+ last_glb_path = save_path
+ normal_image = distance2pointcloud(distance, cv2_image, mask, save_path=save_path.replace('.glb', '.ply'), return_normal=True, save_distance=False)
+ ply2glb(save_path.replace('.glb', '.ply'), save_path)
+ return save_path, [distance_vis, normal_image]
+
+inputs = [
+ gr.Image(label="Input Image", type="filepath"),
+ gr.Image(label="Input Mask", type="filepath"),
+]
+outputs = [
+ gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Point Cloud"),
+ gr.ImageSlider(
+ label="Output Depth / Normal (transformed from the depth)",
+ type="pil",
+ slider_position=75,
+ )
+]
+
+demo = gr.Interface(
+ fn=fn,
+ title="DA2: Depth Anything in Any Direction",
+    description="""
+Please consider starring ★ our GitHub Repo if you find this demo useful!
+
+Note: the "Input Mask" is optional; all pixels are assumed to be valid if the mask is None.
+    """,
+ inputs=inputs,
+ outputs=outputs,
+ examples=[
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a1.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a2.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a3.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a4.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/b0.png"),
+ os.path.join(os.path.dirname(__file__), "assets/masks/b0.png")],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/b1.png"),
+ os.path.join(os.path.dirname(__file__), "assets/masks/b1.png")],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a5.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a6.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a7.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a8.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/b2.png"),
+ os.path.join(os.path.dirname(__file__), "assets/masks/b2.png")],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/b3.png"),
+ os.path.join(os.path.dirname(__file__), "assets/masks/b3.png")],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a9.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a10.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a11.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/a0.png"), None],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/b4.png"),
+ os.path.join(os.path.dirname(__file__), "assets/masks/b4.png")],
+ [os.path.join(os.path.dirname(__file__), "assets/demos/b5.png"),
+ os.path.join(os.path.dirname(__file__), "assets/masks/b5.png")],
+ ],
+ examples_per_page=20
+)
+
+demo.launch(
+ server_name="0.0.0.0",
+ server_port=6381,
+)
diff --git a/assets/badges/icon2.png b/assets/badges/icon2.png
new file mode 100644
index 0000000000000000000000000000000000000000..847382c831464926fb2b7170b5868930ea86fc34
--- /dev/null
+++ b/assets/badges/icon2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d254fc5009dd41b367790aa9e45f05770b81ed62c67d8cc713bee4608567218f
+size 6765
diff --git a/assets/badges/teaser.jpg b/assets/badges/teaser.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..395a804413725d527c02a1b8bba4935a98812c18
--- /dev/null
+++ b/assets/badges/teaser.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c6786218d0a17115e6ed1320434b2b47101290a7e244f2eed1ebe70e4822464
+size 1201407
diff --git a/assets/demos/a0.png b/assets/demos/a0.png
new file mode 100644
index 0000000000000000000000000000000000000000..267653c09acf2f8a26be91d6fbd4c011eda89acd
--- /dev/null
+++ b/assets/demos/a0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eedc66f98cf0a949602f691c3eed51511ae520cf8f63674abe542741ba6090b8
+size 744091
diff --git a/assets/demos/a1.png b/assets/demos/a1.png
new file mode 100644
index 0000000000000000000000000000000000000000..815eadf96f883a47a2b4fdf171d5cbef8d0c9017
--- /dev/null
+++ b/assets/demos/a1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:906f336ab4c6561ee85b9cb883a6aa34cf11289fc86b6a4e4382baed56981aa7
+size 821703
diff --git a/assets/demos/a10.png b/assets/demos/a10.png
new file mode 100644
index 0000000000000000000000000000000000000000..da37cd6ddab7133ccb9ce12b2d2660bc8228915e
--- /dev/null
+++ b/assets/demos/a10.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6d058aef9322964f5d36de90ab91470e283acab248604bcd488a43c680a9e7d
+size 881818
diff --git a/assets/demos/a11.png b/assets/demos/a11.png
new file mode 100644
index 0000000000000000000000000000000000000000..12620f1f8717c787c482c0b0bbc40c2794d3ae48
--- /dev/null
+++ b/assets/demos/a11.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45af8c71b8d44880503b5da1b5f67a0d5638860b9f9149cae7d16a3a3975d090
+size 848394
diff --git a/assets/demos/a2.png b/assets/demos/a2.png
new file mode 100644
index 0000000000000000000000000000000000000000..2a25f8f6fb8b191147612ef4b114aad93802f436
--- /dev/null
+++ b/assets/demos/a2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fa931d70c6220cec0b56a9cdf651f12fa35436d937cd2cf481d10dddb2a114e
+size 809628
diff --git a/assets/demos/a3.png b/assets/demos/a3.png
new file mode 100644
index 0000000000000000000000000000000000000000..5cfb653c14bc04496ec99a06f0309e922f9ff8df
--- /dev/null
+++ b/assets/demos/a3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a85573ac5d51a261d82b23475488e769bd9b3e392948e60e6dc73f0c7ace762b
+size 854468
diff --git a/assets/demos/a4.png b/assets/demos/a4.png
new file mode 100644
index 0000000000000000000000000000000000000000..c7d9899b683013c1277dc1bf274e59c106259fbc
--- /dev/null
+++ b/assets/demos/a4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0a544ec4b542c59f1fbfaf99f86eb60b4c0dbce7c8e4b1bac9e6e23e889c7ec
+size 812626
diff --git a/assets/demos/a5.png b/assets/demos/a5.png
new file mode 100644
index 0000000000000000000000000000000000000000..716e2c9ad54f622fc73d622dc269db9d35f33915
--- /dev/null
+++ b/assets/demos/a5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e36ed78b74223eae24f8c85f1cdab00d1a3a5b494fec807240cb7d3427fad87
+size 847578
diff --git a/assets/demos/a6.png b/assets/demos/a6.png
new file mode 100644
index 0000000000000000000000000000000000000000..4938bd13ce9be9ba95427b707c644ca6edbb3a3e
--- /dev/null
+++ b/assets/demos/a6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e48031fcd3e5a84e4ea4513a23e2ec8150f8ec3fbdae1d4b2d51fc67ac588fe6
+size 818477
diff --git a/assets/demos/a7.png b/assets/demos/a7.png
new file mode 100644
index 0000000000000000000000000000000000000000..f94efe1f72d5c0422246f5d236617f0d6c69262a
--- /dev/null
+++ b/assets/demos/a7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12b99fdddea8eefb6885114bd386fc4fad0484e13c85c88364a43396f9cef3f9
+size 904680
diff --git a/assets/demos/a8.png b/assets/demos/a8.png
new file mode 100644
index 0000000000000000000000000000000000000000..bb1f55614094f39a16c40221801ccbeafc9a506f
--- /dev/null
+++ b/assets/demos/a8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b29df5b6294742acc43d8ce41073b335e98024459273b77d9b943fd3583ac35
+size 784328
diff --git a/assets/demos/a9.png b/assets/demos/a9.png
new file mode 100644
index 0000000000000000000000000000000000000000..faec9983f223633853a0eedd5d4399f0ff108f76
--- /dev/null
+++ b/assets/demos/a9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba92bf3adf1d1b2a775d5b0f895a16876159fc1a43d98328c923fdc994d6e346
+size 910249
diff --git a/assets/demos/b0.png b/assets/demos/b0.png
new file mode 100644
index 0000000000000000000000000000000000000000..c6ec9dd9975fe81f94760b38ca84eab425ab1f6b
--- /dev/null
+++ b/assets/demos/b0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3b610ae826372778853553810ef0e07e4f91d8507549dc0f5f32eca038348a37
+size 850392
diff --git a/assets/demos/b1.png b/assets/demos/b1.png
new file mode 100644
index 0000000000000000000000000000000000000000..27d7c6e8ec92cf3a476966d65e1325c2230f1964
--- /dev/null
+++ b/assets/demos/b1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2df3207be859cf8524e9a00a76efb606e626ca4cc9dbd81178fe24de43a6b97b
+size 798128
diff --git a/assets/demos/b2.png b/assets/demos/b2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4beb346587164c6867c2859d848106aa5f02522d
--- /dev/null
+++ b/assets/demos/b2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:790218133cd507f1f9ca65fcdff60f74325df39ebd0df1d5b6e6261a8dfd29a8
+size 863217
diff --git a/assets/demos/b3.png b/assets/demos/b3.png
new file mode 100644
index 0000000000000000000000000000000000000000..3acd2114400b10a557021a5904e906cf00b8942b
--- /dev/null
+++ b/assets/demos/b3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:843b680077e114451285efc6536e811739cbbab07ade423459a5bc24e747455f
+size 650671
diff --git a/assets/demos/b4.png b/assets/demos/b4.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ba15860a680613beb8ea710414407470f26a6e8
--- /dev/null
+++ b/assets/demos/b4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5615e49fa1bea5ee049a66bbe577d48dd63f441e86a4ae5b225136e7e2295187
+size 804398
diff --git a/assets/demos/b5.png b/assets/demos/b5.png
new file mode 100644
index 0000000000000000000000000000000000000000..cccf4ea0427459d11d78447e7ece5a2dd89304e0
--- /dev/null
+++ b/assets/demos/b5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7957ee9e54dd6b61b74014412ece3de7bbe999ae0c0be41c4d762d62d8352656
+size 669137
diff --git a/assets/masks/b0.png b/assets/masks/b0.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e6417f05a4b4dcebaab35d11352e5fa2e257430
--- /dev/null
+++ b/assets/masks/b0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7495c6c7672f1b0551f5640a0344a3730744cfa535697307afa917fbf46466ad
+size 6993
diff --git a/assets/masks/b1.png b/assets/masks/b1.png
new file mode 100644
index 0000000000000000000000000000000000000000..c7d1fd480be44bfd86c4002186f9cb0c69eb4cda
--- /dev/null
+++ b/assets/masks/b1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1aea3b6a9a99adbcdb71fcbc9eb5c5f18fbdc36b38829d7ba972183a7ec564e3
+size 5357
diff --git a/assets/masks/b2.png b/assets/masks/b2.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a48bd5586e7dd2e05a4b10e093a7ab8118837b5
--- /dev/null
+++ b/assets/masks/b2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4360d8523cb2309b29ed549c6a7c84dd0d6a3ca5f55720ae43b728668dfe6c9b
+size 7703
diff --git a/assets/masks/b3.png b/assets/masks/b3.png
new file mode 100644
index 0000000000000000000000000000000000000000..7a332ee4df7d8f4420fe536dedd5e46884b00667
--- /dev/null
+++ b/assets/masks/b3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d1e6f1d40d8f9e8e5593bf3f5fe67967528b8afcbfaf605658f19004edbdb10d
+size 4568
diff --git a/assets/masks/b4.png b/assets/masks/b4.png
new file mode 100644
index 0000000000000000000000000000000000000000..8d63e68b058bd3594d2f464bdad92a9bb4c6ee2e
--- /dev/null
+++ b/assets/masks/b4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a2a1018ad95749d83193fc0f333e1af04de119857e2564c5fbefa41301f2226
+size 5428
diff --git a/assets/masks/b5.png b/assets/masks/b5.png
new file mode 100644
index 0000000000000000000000000000000000000000..acf9f30cf3d75c007ba3e1fe04e84a319d9ca7c8
--- /dev/null
+++ b/assets/masks/b5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c38cca29eec4baaeb7b765f595f28d13e1fcaf7707bed7ad83277b12eee1f504
+size 4883
diff --git a/configs/accelerate/0.yaml b/configs/accelerate/0.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6f7938e128402b64e2690255025224e303654b52
--- /dev/null
+++ b/configs/accelerate/0.yaml
@@ -0,0 +1,16 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: 'NO'
+downcast_bf16: 'no'
+gpu_ids: '0'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/configs/infer.json b/configs/infer.json
new file mode 100755
index 0000000000000000000000000000000000000000..55ecd9accb3d2562cb1a0c5e236bb49bcf0f1c65
--- /dev/null
+++ b/configs/infer.json
@@ -0,0 +1,39 @@
+{
+ "env": {
+ "seed": 42,
+ "verbose": true
+ },
+ "accelerator": {
+ "report_to": ["tensorboard"],
+ "mixed_precision": "fp16",
+ "accumulation_nsteps": 4,
+ "timeout": 36000
+ },
+ "inference": {
+ "images": "assets/demos",
+ "masks": "assets/masks",
+ "min_pixels": 580000,
+ "max_pixels": 620000
+ },
+ "spherevit": {
+ "vit_w_esphere": {
+ "input_dims": [1024, 1024, 1024, 1024],
+ "hidden_dim": 512,
+ "num_heads": 8,
+ "expansion": 4,
+ "num_layers_head": [2, 2, 2],
+ "dropout": 0.0,
+ "layer_scale": 0.0001,
+ "out_dim": 64,
+ "kernel_size": 3,
+ "num_prompt_blocks": 1,
+ "use_norm": false
+ },
+ "sphere": {
+ "width": 1092,
+ "height": 546,
+ "hfov": 6.2832,
+ "vfov": 3.1416
+ }
+ }
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3cceff4cade5de6cbc69eed216433351c475866b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+-e src/
\ No newline at end of file
diff --git a/src/da2.egg-info/PKG-INFO b/src/da2.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..8a628b59766e52f98401f3e400f59e7b61280ecd
--- /dev/null
+++ b/src/da2.egg-info/PKG-INFO
@@ -0,0 +1,23 @@
+Metadata-Version: 2.4
+Name: da2
+Version: 0.1.0
+Summary: For the implementation of DA^2: Depth Anything in Any Direction
+Author-email: "H. Li"
+Requires-Dist: torch==2.5.0
+Requires-Dist: torchvision==0.20.0
+Requires-Dist: torchaudio==2.5.0
+Requires-Dist: xformers==0.0.28.post2
+Requires-Dist: diffusers==0.32.0
+Requires-Dist: tensorboard==2.18.0
+Requires-Dist: utils3d@ git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900
+Requires-Dist: opencv-python==4.12.0.88
+Requires-Dist: gradio==5.49.0
+Requires-Dist: gradio-client==1.13.3
+Requires-Dist: gradio-imageslider==0.0.20
+Requires-Dist: accelerate==1.1.1
+Requires-Dist: omegaconf==2.3.0
+Requires-Dist: tabulate==0.9.0
+Requires-Dist: einops==0.8.0
+Requires-Dist: timm==1.0.15
+Requires-Dist: trimesh==4.5.2
+Requires-Dist: transformers==4.46.3
diff --git a/src/da2.egg-info/SOURCES.txt b/src/da2.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d75e99d41d9df2936757319c33ca8f10a772280e
--- /dev/null
+++ b/src/da2.egg-info/SOURCES.txt
@@ -0,0 +1,28 @@
+pyproject.toml
+da2/__init__.py
+da2.egg-info/PKG-INFO
+da2.egg-info/SOURCES.txt
+da2.egg-info/dependency_links.txt
+da2.egg-info/requires.txt
+da2.egg-info/top_level.txt
+da2/model/__init__.py
+da2/model/base.py
+da2/model/sphere.py
+da2/model/spherevit.py
+da2/model/vit_w_esphere.py
+da2/model/dinov2/__init__.py
+da2/model/dinov2/attention.py
+da2/model/dinov2/block.py
+da2/model/dinov2/dino_head.py
+da2/model/dinov2/dinovit.py
+da2/model/dinov2/drop_path.py
+da2/model/dinov2/layer_scale.py
+da2/model/dinov2/mlp.py
+da2/model/dinov2/patch_embed.py
+da2/model/dinov2/swiglu_ffn.py
+da2/utils/__init__.py
+da2/utils/base.py
+da2/utils/d2pc.py
+da2/utils/io.py
+da2/utils/model.py
+da2/utils/vis.py
\ No newline at end of file
diff --git a/src/da2.egg-info/dependency_links.txt b/src/da2.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/src/da2.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/src/da2.egg-info/requires.txt b/src/da2.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..313b5fd2bed588a94da9f04276a28a465976801b
--- /dev/null
+++ b/src/da2.egg-info/requires.txt
@@ -0,0 +1,18 @@
+torch==2.5.0
+torchvision==0.20.0
+torchaudio==2.5.0
+xformers==0.0.28.post2
+diffusers==0.32.0
+tensorboard==2.18.0
+utils3d@ git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900
+opencv-python==4.12.0.88
+gradio==5.49.0
+gradio-client==1.13.3
+gradio-imageslider==0.0.20
+accelerate==1.1.1
+omegaconf==2.3.0
+tabulate==0.9.0
+einops==0.8.0
+timm==1.0.15
+trimesh==4.5.2
+transformers==4.46.3
diff --git a/src/da2.egg-info/top_level.txt b/src/da2.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e9f3aca6d46fc8bd68eef426cddaa20a40de4253
--- /dev/null
+++ b/src/da2.egg-info/top_level.txt
@@ -0,0 +1 @@
+da2
diff --git a/src/da2/__init__.py b/src/da2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cb1dddc3c657f4ce1b627593983469f7fbfd94b
--- /dev/null
+++ b/src/da2/__init__.py
@@ -0,0 +1,25 @@
+from .utils.base import (
+ prepare_to_run
+)
+from .utils.model import (
+ load_model
+)
+from .utils.io import (
+ load_infer_data
+)
+from .utils.vis import (
+ colorize_distance,
+ concatenate_images
+)
+from .utils.d2pc import (
+ distance2pointcloud
+)
+
+__all__ = [
+ 'prepare_to_run',
+ 'load_model',
+ 'load_infer_data',
+ 'colorize_distance',
+ 'concatenate_images',
+ 'distance2pointcloud'
+]
diff --git a/src/da2/__pycache__/__init__.cpython-312.pyc b/src/da2/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82b6c82f97676a7dfc14c2e303e9a83e818ec09e
Binary files /dev/null and b/src/da2/__pycache__/__init__.cpython-312.pyc differ
diff --git a/src/da2/model/__init__.py b/src/da2/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..df880766b1a941dbe4e413c65aae041e5980c3e9
--- /dev/null
+++ b/src/da2/model/__init__.py
@@ -0,0 +1,11 @@
+from .spherevit import (
+ SphereViT
+)
+from .vit_w_esphere import (
+ ViT_w_Esphere
+)
+
+__all__ = [
+ 'SphereViT',
+ 'ViT_w_Esphere',
+]
diff --git a/src/da2/model/__pycache__/__init__.cpython-312.pyc b/src/da2/model/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e49302a36e05a435a3381fa203f125e421359d48
Binary files /dev/null and b/src/da2/model/__pycache__/__init__.cpython-312.pyc differ
diff --git a/src/da2/model/__pycache__/base.cpython-312.pyc b/src/da2/model/__pycache__/base.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b5ab12b142a2a78e3a7c15ed35f7a2288b78f03d
Binary files /dev/null and b/src/da2/model/__pycache__/base.cpython-312.pyc differ
diff --git a/src/da2/model/__pycache__/sphere.cpython-312.pyc b/src/da2/model/__pycache__/sphere.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24317b459f977136fdf0601bc0cd61787ff132c8
Binary files /dev/null and b/src/da2/model/__pycache__/sphere.cpython-312.pyc differ
diff --git a/src/da2/model/__pycache__/spherevit.cpython-312.pyc b/src/da2/model/__pycache__/spherevit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..baa794d496bf1f6153affb58840f9ac2709feeb7
Binary files /dev/null and b/src/da2/model/__pycache__/spherevit.cpython-312.pyc differ
diff --git a/src/da2/model/__pycache__/vit_w_esphere.cpython-312.pyc b/src/da2/model/__pycache__/vit_w_esphere.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6743cfac8ca352235bd0ef9e1ec44dd81ce58d54
Binary files /dev/null and b/src/da2/model/__pycache__/vit_w_esphere.cpython-312.pyc differ
diff --git a/src/da2/model/base.py b/src/da2/model/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..03d467f3184133b590f76af9ff453386719d0736
--- /dev/null
+++ b/src/da2/model/base.py
@@ -0,0 +1,395 @@
+import torch
+import torch.nn as nn
+from math import log2, pi
+from typing import Tuple
+import torch.nn.functional as F
+from einops import rearrange
+from functools import partial
+
+
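+# Fourier feature expansion: lifts each input coordinate to sin (and optionally cos) features at log- or linearly-spaced frequencies.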
+def fourier_dimension_expansion(
+ x: torch.Tensor,
+ dim: int = 512,
+ max_freq: int = 64,
+ use_cos: bool = True,
+ use_log: bool = True,
+):
+ device, dtype, input_dim = x.device, x.dtype, x.shape[-1]
+ # input_dim: 2
+ num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim
+    # e.g. dim=512, input_dim=2: num_bands = 512 // 4 = 128 with cos, or 512 // 2 = 256 without
+ if use_log:
+ scales = 2.0 ** torch.linspace(
+ 0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype
+ )
+ else:
+ scales = torch.linspace(
+ 1.0, max_freq / 2, num_bands, device=device, dtype=dtype
+ )
+ x = x.unsqueeze(-1)
+ scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
+ x = x * scales * pi
+ x = torch.cat(
+ (
+ [x.sin(), x.cos()]
+ if use_cos
+ else [
+ x.sin(),
+ ]
+ ),
+ dim=-1,
+ )
+ x = x.flatten(-2)
+ return x
+
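+# Despite its name, this resamples a flattened (B, H*W, C) token sequence from spatial size old to new with nearest-neighbour interpolation.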
+def flatten(
+ flat_tensor: torch.Tensor,
+ old: Tuple[int, int],
+ new: Tuple[int, int],
+) -> torch.Tensor:
+ if old[0] == new[0] and old[1] == new[1]:
+ return flat_tensor
+ tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute(
+ 0, 3, 1, 2
+ ) # b c h w
+ tensor_interp = F.interpolate(
+ tensor,
+ size=(new[0], new[1]),
+ mode='nearest',
+ )
+ flat_tensor_interp = tensor_interp.view(
+ flat_tensor.shape[0], -1, new[0] * new[1]
+ ).permute(
+ 0, 2, 1
+ ) # b (h w) c
+ return flat_tensor_interp.contiguous()
+
+
+class DimensionAligner(nn.Module):
+ def __init__(self, input_dims: list[int], hidden_dim: int):
+ super().__init__()
+ self.aligners = nn.ModuleList([])
+ self.num_chunks = len(input_dims)
+ self.checkpoint = True
+ for input_dim in input_dims:
+ self.aligners.append(nn.Linear(input_dim, hidden_dim))
+
+    def forward(self, xs: list[torch.Tensor]) -> list[torch.Tensor]:
+ outs = [self.aligners[i](x) for i, x in enumerate(xs)]
+ return outs
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: float | torch.Tensor = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+
+class SwiGLU(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x, gates = x.chunk(2, dim=-1)
+ return x * F.silu(gates)
+
+
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ gated: bool = False,
+ output_dim: int | None = None,
+ ):
+ super().__init__()
+ if gated:
+ expansion = int(expansion * 2 / 3)
+ hidden_dim = int(input_dim * expansion)
+ output_dim = default(output_dim, input_dim)
+ self.norm = nn.LayerNorm(input_dim)
+ self.proj1 = nn.Linear(input_dim, hidden_dim)
+ self.proj2 = nn.Linear(hidden_dim, output_dim)
+ self.act = nn.GELU() if not gated else SwiGLU()
+ self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x)
+ x = self.proj1(x)
+ x = self.act(x)
+ x = self.proj2(x)
+ x = self.dropout(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 4,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ cosine: bool = False,
+ gated: bool = False,
+ layer_scale: float = 1.0,
+ context_dim: int | None = None,
+ detach_query: bool = False,
+ residual_ls: bool = False,
+ ):
+ super().__init__()
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.hidden_dim = dim
+ context_dim = dim if context_dim is None else context_dim
+ self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
+ self.kv = nn.Linear(context_dim, dim * 2, bias=False)
+ self.q = nn.Linear(dim, dim, bias=False)
+ self.norm_attnx = nn.LayerNorm(dim)
+ self.norm_attnctx = nn.LayerNorm(context_dim)
+ self.cosine = cosine
+ self.out = nn.Linear(dim, dim, bias=False)
+ self.ls1_1 = (
+ LayerScale(dim, layer_scale)
+ if layer_scale > 0.0 and not residual_ls
+ else nn.Identity()
+ )
+ self.ls1_2 = (
+ LayerScale(dim, layer_scale)
+ if layer_scale > 0.0 and residual_ls
+ else nn.Identity()
+ )
+ self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+ self.detach_query = detach_query
+
+ def attn(
+ self,
+ x: torch.Tensor,
+ attn_bias: torch.Tensor | None = None,
+ context: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ pos_embed_context: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ rope_pos: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if self.detach_query:
+ x = x.detach()
+ x = self.norm_attnx(x)
+ context = self.norm_attnctx(context)
+ k, v = rearrange(
+ self.kv(context), 'b n (kv h d) -> b h n d kv', h=self.num_heads, kv=2
+ ).unbind(dim=-1)
+ q = rearrange(self.q(x), 'b n (h d) -> b h n d', h=self.num_heads)
+
+ if rope is not None:
+ q = rope(q.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3)
+ k = rope(k.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3)
+ else:
+ if pos_embed is not None:
+ pos_embed = rearrange(
+ pos_embed, 'b n (h d) -> b h n d', h=self.num_heads
+ )
+ q = q + pos_embed
+ if pos_embed_context is not None:
+ pos_embed_context = rearrange(
+ pos_embed_context, 'b n (h d) -> b h n d', h=self.num_heads
+ )
+ k = k + pos_embed_context
+
+ if self.cosine:
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
+
+ x = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+ )
+ x = rearrange(x, 'b h n d -> b n (h d)')
+ x = self.out(x)
+ return x
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ pos_embed_context: torch.Tensor | None = None,
+ attn_bias: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ rope_pos: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ context = x if context is None else context
+ x = self.ls1_1(
+ self.attn(
+ x,
+ rope=rope,
+ rope_pos=rope_pos,
+ attn_bias=attn_bias,
+ context=context,
+ pos_embed=pos_embed,
+ pos_embed_context=pos_embed_context,
+ )
+ ) + self.ls1_2(x)
+ x = self.ls2(self.mlp(x)) + x
+ return x
+
+
+class AttentionSeq(nn.Module):
+ def __init__(
+ self,
+ num_blocks: int,
+ dim: int,
+ num_heads: int = 4,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ cosine: bool = False,
+ gated: bool = False,
+ layer_scale: float = 1.0,
+ context_dim: int | None = None,
+ detach_query: bool = False,
+ residual_ls: bool = False,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [
+ AttentionBlock(
+ dim=dim,
+ num_heads=num_heads,
+ expansion=expansion,
+ dropout=dropout,
+ cosine=cosine,
+ gated=gated,
+ layer_scale=layer_scale,
+ context_dim=context_dim,
+ detach_query=detach_query,
+ residual_ls=residual_ls,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ pos_embed_context: torch.Tensor | None = None,
+ attn_bias: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ rope_pos: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ for layer in self.layers:
+ x = layer(
+ x,
+ context=context,
+ pos_embed=pos_embed,
+ pos_embed_context=pos_embed_context,
+ attn_bias=attn_bias,
+ rope=rope,
+ rope_pos=rope_pos,
+ )
+ return x
+
+
+class ResidualConvNet(nn.Module):
+ def __init__(
+ self,
+ dim,
+ kernel_size: int = 3,
+ padding_mode: str = 'zeros',
+ dilation: int = 1,
+ layer_scale: float = 1.0,
+ use_norm: bool = False,
+ ):
+ super().__init__()
+ self.conv1 = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=dilation * (kernel_size - 1) // 2,
+ dilation=dilation,
+ padding_mode=padding_mode,
+ )
+ self.conv2 = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=dilation * (kernel_size - 1) // 2,
+ dilation=dilation,
+ padding_mode=padding_mode,
+ )
+ self.activation = nn.LeakyReLU()
+ self.gamma = (
+ nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1))
+ if layer_scale > 0.0
+ else 1.0
+ )
+ self.norm1 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity()
+ self.norm2 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity()
+
+ def forward(self, x):
+ out = self.activation(x)
+ out = self.conv1(out)
+ out = self.norm1(out)
+ out = self.activation(out)
+ out = self.conv2(out)
+ out = self.norm2(out)
+ return self.gamma * out + x
+
+
+class ResidualUpsampler(nn.Module):
+ def __init__(
+ self,
+ hidden_dim,
+ output_dim: int = None,
+ num_layers: int = 2,
+ kernel_size: int = 3,
+ layer_scale: float = 1.0,
+ padding_mode: str = 'zeros',
+ use_norm: bool = False,
+ **kwargs,
+ ):
+ super().__init__()
+ output_dim = output_dim if output_dim is not None else hidden_dim // 2
+ self.convs = nn.ModuleList([])
+ for _ in range(num_layers):
+ self.convs.append(
+ ResidualConvNet(
+ hidden_dim,
+ kernel_size=kernel_size,
+ layer_scale=layer_scale,
+ padding_mode=padding_mode,
+ use_norm=use_norm,
+ )
+ )
+ self.up = nn.Sequential(
+ nn.Conv2d(
+ hidden_dim,
+ output_dim,
+ kernel_size=1,
+ padding=0,
+ padding_mode=padding_mode,
+ ),
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+ )
+
+ def forward(self, x: torch.Tensor):
+ for conv in self.convs:
+ x = conv(x)
+ x = self.up(x)
+ return x
diff --git a/src/da2/model/dinov2/__init__.py b/src/da2/model/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..89f2115deeced559d02bd871f98ec86e9ec6d97c
--- /dev/null
+++ b/src/da2/model/dinov2/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .dinovit import (
+ DINOViT
+)
+
+__all__ = [
+ 'DINOViT'
+]
diff --git a/src/da2/model/dinov2/__pycache__/__init__.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a3b0109aacc0704664188faedff5938ebca3493
Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/__init__.cpython-312.pyc differ
diff --git a/src/da2/model/dinov2/__pycache__/attention.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/attention.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f27213c6558f754ed1691024894be5b0f2a0377f
Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/attention.cpython-312.pyc differ
diff --git a/src/da2/model/dinov2/__pycache__/block.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/block.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..834d0c242bf193fbf4bb0b091c3b51518c67894d
Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/block.cpython-312.pyc differ
diff --git a/src/da2/model/dinov2/__pycache__/dinovit.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/dinovit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..843708a72c84d53570064acaf2b195930f6c5500
Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/dinovit.cpython-312.pyc differ
diff --git a/src/da2/model/dinov2/__pycache__/drop_path.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/drop_path.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c184bb73a9b53af7d9db2636b939be7df22857c
Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/drop_path.cpython-312.pyc differ
diff --git a/src/da2/model/dinov2/__pycache__/layer_scale.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/layer_scale.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f949b8580fe05d570db78412686b7ff47df7501a
Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/layer_scale.cpython-312.pyc differ
diff --git a/src/da2/model/dinov2/__pycache__/mlp.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/mlp.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e40153264bdb86c6fa52a8c68ac87c90075b69da
Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/mlp.cpython-312.pyc differ
diff --git a/src/da2/model/dinov2/__pycache__/patch_embed.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/patch_embed.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1bf7602032d36a56eb10300939c095deed6c7321
Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/patch_embed.cpython-312.pyc differ
diff --git a/src/da2/model/dinov2/__pycache__/swiglu_ffn.cpython-312.pyc b/src/da2/model/dinov2/__pycache__/swiglu_ffn.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6494419f8663a907cc9b66f797dbe19510bde169
Binary files /dev/null and b/src/da2/model/dinov2/__pycache__/swiglu_ffn.cpython-312.pyc differ
diff --git a/src/da2/model/dinov2/attention.py b/src/da2/model/dinov2/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..8348acc8d18745a2275f6d76924c8a4e03ce324e
--- /dev/null
+++ b/src/da2/model/dinov2/attention.py
@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha, memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0.0 else nn.Identity()
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ x = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2])
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
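+# Uses xFormers memory-efficient attention when available; falls back to the parent implementation on CPU or when xFormers is missing.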
+class MemEffAttention(Attention):
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ if not XFORMERS_AVAILABLE or x.device.type == "cpu":
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/src/da2/model/dinov2/block.py b/src/da2/model/dinov2/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..e50619cc568d0a3c91e5630d9bec4a26212150c0
--- /dev/null
+++ b/src/da2/model/dinov2/block.py
@@ -0,0 +1,281 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Any, Callable, Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+logger = logging.getLogger("dinov2")
+
+try:
+ from xformers.ops import fmha, index_select_cat, scaled_index_add
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
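+# Batch-level stochastic depth: the residual branch is computed on a random subset of samples and added back with scale b / subset_size.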
+def drop_add_residual_stochastic_depth(
+ x: torch.Tensor,
+ residual_func, #: Callable[[torch.Tensor], torch.Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> torch.Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+ )
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+ )
+ else:
+ x_plus_residual = scaled_index_add(
+ x,
+ brange,
+ residual.to(dtype=x.dtype),
+ scaling=scaling_vector,
+ alpha=residual_scale_factor,
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = (
+ [b.shape[0] for b in branges]
+ if branges is not None
+ else [x.shape[0] for x in x_list]
+ )
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
+ 1, -1, x_list[0].shape[-1]
+ )
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[torch.Tensor],
+ residual_func, #: Callable[[torch.Tensor, Any], torch.Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> torch.Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [
+ get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
+ ]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(
+ x_list, branges, residual_list, residual_scale_factors
+ ):
+ outputs.append(
+ add_residual(
+ x, brange, residual, residual_scale_factor, scaling_vector
+ ).view_as(x)
+ )
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=(
+ self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
+ ),
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=(
+                    self.ls2.gamma if isinstance(self.ls2, LayerScale) else None
+ ),
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, torch.Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ assert (
+ XFORMERS_AVAILABLE
+ ), "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/src/da2/model/dinov2/dino_head.py b/src/da2/model/dinov2/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1147dd3a3c046aee8d427b42b1055f38a218275b
--- /dev/null
+++ b/src/da2/model/dinov2/dino_head.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(
+ nlayers,
+ in_dim,
+ bottleneck_dim,
+ hidden_dim=hidden_dim,
+ use_bn=use_bn,
+ bias=mlp_bias,
+ )
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(
+ nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
+):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/src/da2/model/dinov2/dinovit.py b/src/da2/model/dinov2/dinovit.py
new file mode 100644
index 0000000000000000000000000000000000000000..04e679d67ec8024c23f788d225656b621165825b
--- /dev/null
+++ b/src/da2/model/dinov2/dinovit.py
@@ -0,0 +1,223 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+import math
+import torch
+import torch.nn as nn
+import contextlib
+from functools import partial
+from typing import Sequence
+from .block import (
+ Block
+)
+from .attention import (
+ MemEffAttention
+)
+from .mlp import (
+ Mlp
+)
+from .patch_embed import (
+ PatchEmbed
+)
+from .swiglu_ffn import (
+ SwiGLUFFNFused
+)
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha, memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class DINOViT(nn.Module):
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=14,
+ in_chans=3,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=1.0,
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ ffn_layer="mlp",
+ block_chunks=0,
+ output_idx=[6, 12, 18, 24],
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ use_norm=True,
+ frozen_stages=0,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ """
+ super().__init__()
+ self.frozen_stages = frozen_stages
+ self.patch_size = patch_size
+ self.output_idx = output_idx
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + 1, embed_dim)
+ )
+ assert num_register_tokens >= 0
+ self.register_tokens = nn.Parameter(
+ torch.zeros(1, max(1, num_register_tokens), embed_dim)
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ]
+
+ if ffn_layer == "mlp":
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ def f():
+ return nn.Identity()
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=nn.LayerNorm,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = nn.LayerNorm(embed_dim)
+ self.use_norm = use_norm
+ self.head = nn.Identity()
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
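+    # Bicubically resizes the square pretraining position-embedding grid to the current token grid, keeping the class-token embedding unchanged.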
+ def interpolate_pos_encoding(self, x, W, H):
+ previous_dtype = x.dtype
+ N = self.pos_embed.shape[1] - 1
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = W // self.patch_size
+ h0 = H // self.patch_size
+
+ M = int(math.sqrt(N))
+ assert N == M * M
+ kwargs = {}
+ kwargs["size"] = (w0, h0)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
+ previous_dtype
+ )
+
+ def tokenize(self, x):
+ _, _, W, H = x.shape
+ with torch.no_grad() if self.frozen_stages > -1 else contextlib.nullcontext():
+ x = self.patch_embed(x)
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ dino_pos_embed = self.interpolate_pos_encoding(x, W, H)
+ x = x + dino_pos_embed
+ return x
+
+ def forward_features(self, x):
+ shapes = [val // self.patch_size for val in x.shape[-2:]]
+ batch_size = x.shape[0]
+ features = []
+ x = self.tokenize(x)
+ for i, blk in enumerate(self.blocks):
+ with (
+ torch.no_grad() if i < self.frozen_stages else contextlib.nullcontext()
+ ):
+ x = blk(x)
+ features.append(x)
+ if self.use_norm:
+ with (
+ torch.no_grad()
+ if self.frozen_stages >= len(self.blocks)
+ else contextlib.nullcontext()
+ ):
+ features = [self.norm(out) for out in features]
+ features = [out[:, self.num_register_tokens + 1 :] for out in features]
+ features = [out.reshape(batch_size, *shapes, -1) for out in features]
+ return features
+
+ def forward(self, *args):
+ features = self.forward_features(*args)
+ return features
diff --git a/src/da2/model/dinov2/drop_path.py b/src/da2/model/dinov2/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..35b1a620d06ba862ea05297d271d8c2c625b5f93
--- /dev/null
+++ b/src/da2/model/dinov2/drop_path.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+import torch.nn as nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
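+
+
+# Minimal sanity-check sketch (an illustrative addition, not upstream code):
+# during training, DropPath zeroes entire samples of the residual branch with
+# probability `drop_prob` and rescales the survivors by 1 / (1 - drop_prob),
+# keeping the expected output equal to the input; at eval time it is a no-op.
+if __name__ == "__main__":
+    import torch
+    torch.manual_seed(0)
+    layer = DropPath(drop_prob=0.25)
+    layer.train()
+    x = torch.ones(8, 4, 16)             # (batch, tokens, dim)
+    y = layer(x)
+    print(y[:, 0, 0])                    # each entry is either 0 or 1 / 0.75
+    layer.eval()
+    assert torch.equal(layer(x), x)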
diff --git a/src/da2/model/dinov2/layer_scale.py b/src/da2/model/dinov2/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..40d18b5427183534d5516652b076f9883a609fc6
--- /dev/null
+++ b/src/da2/model/dinov2/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
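+
+
+# Illustrative sketch (not part of the upstream file): LayerScale multiplies the
+# last (channel) dimension by a learnable per-channel gain; the tiny default init
+# (1e-5) lets each residual branch start close to an identity mapping.
+if __name__ == "__main__":
+    ls = LayerScale(dim=4, init_values=0.5)
+    x = torch.ones(2, 3, 4)
+    print(ls(x)[0, 0])                   # tensor([0.5000, 0.5000, 0.5000, 0.5000], ...)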
diff --git a/src/da2/model/dinov2/mlp.py b/src/da2/model/dinov2/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..af598999855d897948142cc986fce82abc9e3b53
--- /dev/null
+++ b/src/da2/model/dinov2/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
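+
+
+# Quick shape-check sketch (illustrative addition): the MLP expands tokens to
+# `hidden_features`, applies the activation (and dropout when requested), then
+# projects back to `out_features`.
+if __name__ == "__main__":
+    import torch
+    mlp = Mlp(in_features=8, hidden_features=32)
+    print(mlp(torch.randn(2, 5, 8)).shape)   # torch.Size([2, 5, 8])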
diff --git a/src/da2/model/dinov2/patch_embed.py b/src/da2/model/dinov2/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a56c02609e67922eb8f859588ef274e5298b55
--- /dev/null
+++ b/src/da2/model/dinov2/patch_embed.py
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+import torch.nn as nn
+from torch import Tensor
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert (
+ H % patch_H == 0
+ ), f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert (
+ W % patch_W == 0
+ ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = (
+ Ho
+ * Wo
+ * self.embed_dim
+ * self.in_chans
+ * (self.patch_size[0] * self.patch_size[1])
+ )
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
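+
+
+# Shape-check sketch (illustrative addition): a 224x224 RGB image split into
+# 16x16 patches gives a 14x14 grid, i.e. 196 tokens of dimension `embed_dim`.
+if __name__ == "__main__":
+    import torch
+    pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
+    print(pe(torch.randn(2, 3, 224, 224)).shape)   # torch.Size([2, 196, 768])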
diff --git a/src/da2/model/dinov2/swiglu_ffn.py b/src/da2/model/dinov2/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e82999e9b09b41cd6aba9edbc4c05d51ab663a1e
--- /dev/null
+++ b/src/da2/model/dinov2/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
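+
+
+# Sketch of the hidden-width bookkeeping above (illustrative addition): before
+# calling the parent (xformers-fused when available), SwiGLUFFNFused scales the
+# requested hidden size by 2/3 (the gated branch doubles the first projection,
+# so this keeps the parameter count comparable to a plain MLP) and rounds it up
+# to a multiple of 8. The plain SwiGLUFFN below shows the resulting shapes.
+if __name__ == "__main__":
+    import torch
+    requested = 4 * 1024                          # typical mlp_ratio * embed_dim
+    rounded = (int(requested * 2 / 3) + 7) // 8 * 8
+    print(rounded)                                # 2736
+    ffn = SwiGLUFFN(in_features=1024, hidden_features=rounded)
+    print(ffn(torch.randn(2, 7, 1024)).shape)     # torch.Size([2, 7, 1024])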
diff --git a/src/da2/model/sphere.py b/src/da2/model/sphere.py
new file mode 100644
index 0000000000000000000000000000000000000000..a969bdfffe707052793b9c399f9fdc6f54b96102
--- /dev/null
+++ b/src/da2/model/sphere.py
@@ -0,0 +1,30 @@
+import torch
+
+
+def get_uv_grid(h, w, device):
+ pixel_coords_x = torch.linspace(0.5, w - 0.5, w, device=device)
+ pixel_coords_y = torch.linspace(0.5, h - 0.5, h, device=device)
+ stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()]
+ grid = torch.stack(stacks, dim=0).float()
+ grid = grid.to(device).unsqueeze(0)
+ return grid
+
+
+class Sphere():
+ def __init__(self, config, device):
+ self.config = config
+ self.device = device
+
+ def get_directions(self, shape):
+ h, w = shape
+        uv = get_uv_grid(h, w, device=self.device)
+ u, v = uv.unbind(dim=1)
+ width, height = self.config['width'], self.config['height']
+ hfov, vfov = self.config['hfov'], self.config['vfov']
+ longitude = (u - width / 2) / width * hfov
+ latitude = (v - height / 2) / height * vfov
+ x = torch.cos(latitude) * torch.sin(longitude)
+ z = torch.cos(latitude) * torch.cos(longitude)
+ y = torch.sin(latitude)
+ sphere_directions = torch.stack([x, y, z], dim=1)
+ return sphere_directions
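+
+
+# Small usage sketch (illustrative addition; the config values below are assumed
+# for a full 360 x 180 degree panorama, not read from the shipped config): every
+# pixel centre is mapped to a unit direction vector on the sphere.
+if __name__ == "__main__":
+    from math import pi
+    sphere = Sphere(config={'width': 8, 'height': 4, 'hfov': 2 * pi, 'vfov': pi}, device='cpu')
+    dirs = sphere.get_directions(shape=(4, 8))                # (1, 3, 4, 8)
+    print(dirs.shape, torch.linalg.norm(dirs, dim=1).max())   # norms are all 1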
diff --git a/src/da2/model/spherevit.py b/src/da2/model/spherevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..6927e4944eade89d108f50f4bb14f1ef6ed5aae1
--- /dev/null
+++ b/src/da2/model/spherevit.py
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+from math import (
+ ceil,
+ sqrt
+)
+from huggingface_hub import PyTorchModelHubMixin
+import torchvision.transforms.v2.functional as TF
+from .dinov2 import DINOViT
+from .vit_w_esphere import ViT_w_Esphere
+from .sphere import Sphere
+
+
+IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DATASET_STD = (0.229, 0.224, 0.225)
+
+class SphereViT(nn.Module, PyTorchModelHubMixin):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.dino = DINOViT()
+ self.vit_w_esphere = ViT_w_Esphere(config['spherevit']['vit_w_esphere'])
+ feature_slices = self.dino.output_idx
+ self.feature_slices = list(
+ zip([0, *feature_slices[:-1]], feature_slices)
+ )
+ self.device = None
+
+ def to(self, *args):
+ self.device = args[0]
+ return super().to(*args)
+
+ def forward(self, images):
+ B, _, H, W = images.shape
+ current_pixels = H * W
+ target_pixels = min(self.config['inference']['max_pixels'],
+ max(self.config['inference']['min_pixels'], current_pixels))
+ factor = sqrt(target_pixels / current_pixels)
+ sphere_config = deepcopy(self.config['spherevit']['sphere'])
+ sphere_config['width'] *= factor
+ sphere_config['height'] *= factor
+ sphere = Sphere(config=sphere_config, device=self.device)
+ H_new = int(H * factor)
+ W_new = int(W * factor)
+        DINO_patch_size = 14  # see line 51 of `src/da2/model/dinov2/dinovit.py`; hardcoded here for simplicity
+ H_new = ceil(H_new / DINO_patch_size) * DINO_patch_size
+ W_new = ceil(W_new / DINO_patch_size) * DINO_patch_size
+ images = F.interpolate(images, size=(H_new, W_new), mode='bilinear', align_corners=False)
+ images = TF.normalize(
+ images.float(),
+ mean=IMAGENET_DATASET_MEAN,
+ std=IMAGENET_DATASET_STD,
+ )
+
+ sphere_dirs = sphere.get_directions(shape=(H_new, W_new))
+ sphere_dirs = sphere_dirs.to(self.device)
+ sphere_dirs = sphere_dirs.repeat(B, 1, 1, 1)
+
+ features = self.dino(images)
+ features = [
+ features[i:j][-1].contiguous()
+ for i, j in self.feature_slices
+ ]
+ distance = self.vit_w_esphere(images, features, sphere_dirs)
+ distance = F.interpolate(distance, size=(H, W), mode='bilinear', align_corners=False)
+ distance = distance.squeeze(dim=1) # (b, 1, h, w) -> (b, h, w)
+ return distance
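+
+
+# Back-of-the-envelope sketch (illustrative addition) of the resizing logic in
+# `forward`: the pixel count is clamped into [min_pixels, max_pixels], then both
+# sides are rounded up to the next multiple of the 14-pixel DINO patch. The
+# min/max values below are hypothetical, not the shipped config defaults.
+if __name__ == "__main__":
+    H, W = 1080, 2160
+    min_pixels, max_pixels = 500_000, 1_500_000
+    factor = sqrt(min(max_pixels, max(min_pixels, H * W)) / (H * W))
+    H_new = ceil(int(H * factor) / 14) * 14
+    W_new = ceil(int(W * factor) / 14) * 14
+    print(round(factor, 3), H_new, W_new)                # 0.802 868 1736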
diff --git a/src/da2/model/vit_w_esphere.py b/src/da2/model/vit_w_esphere.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb6e34f2aeee0c6fba48c8716c4218c50b3d0135
--- /dev/null
+++ b/src/da2/model/vit_w_esphere.py
@@ -0,0 +1,224 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from .base import (
+ fourier_dimension_expansion,
+ flatten,
+ DimensionAligner,
+ AttentionSeq,
+ ResidualUpsampler
+)
+
+
+class _ViT_w_Esphere(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int = 8,
+ expansion: int = 4,
+ num_layers_head: int | list[int] = 4,
+ dropout: float = 0.0,
+ kernel_size: int = 7,
+ layer_scale: float = 1.0,
+ out_dim: int = 1,
+ num_prompt_blocks: int = 1,
+ use_norm: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+ self.out_dim = out_dim
+ self.hidden_dim = hidden_dim
+ self.up_sampler = nn.ModuleList([])
+ self.pred_head = nn.ModuleList([])
+ self.process_features = nn.ModuleList([])
+ self.prompt_camera = nn.ModuleList([])
+ mult = 2
+ self.to_latents = nn.Linear(hidden_dim, hidden_dim)
+
+ for _ in range(4):
+ self.prompt_camera.append(
+ AttentionSeq(
+ num_blocks=num_prompt_blocks,
+ dim=hidden_dim,
+ num_heads=num_heads,
+ expansion=expansion,
+ dropout=dropout,
+ layer_scale=-1.0,
+ context_dim=hidden_dim,
+ )
+ )
+
+ for i, depth in enumerate(num_layers_head):
+ current_dim = min(hidden_dim, mult * hidden_dim // int(2**i))
+ next_dim = mult * hidden_dim // int(2 ** (i + 1))
+ output_dim = max(next_dim, out_dim)
+ self.process_features.append(
+ nn.ConvTranspose2d(
+ hidden_dim,
+ current_dim,
+ kernel_size=max(1, 2 * i),
+ stride=max(1, 2 * i),
+ padding=0,
+ )
+ )
+ self.up_sampler.append(
+ ResidualUpsampler(
+ current_dim,
+ output_dim=output_dim,
+ expansion=expansion,
+ layer_scale=layer_scale,
+ kernel_size=kernel_size,
+ num_layers=depth,
+ use_norm=use_norm,
+ )
+ )
+ pred_head = (
+ nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim))
+ if i == len(num_layers_head) - 1
+ else nn.Identity()
+ )
+ self.pred_head.append(pred_head)
+
+ self.to_depth_lr = nn.Conv2d(
+ output_dim,
+ output_dim // 2,
+ kernel_size=3,
+ padding=1,
+ padding_mode='reflect',
+ )
+ self.to_confidence_lr = nn.Conv2d(
+ output_dim,
+ output_dim // 2,
+ kernel_size=3,
+ padding=1,
+ padding_mode='reflect',
+ )
+ self.to_depth_hr = nn.Sequential(
+ nn.Conv2d(
+ output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
+ ),
+ nn.LeakyReLU(),
+ nn.Conv2d(32, 1, kernel_size=1),
+ )
+ self.to_confidence_hr = nn.Sequential(
+ nn.Conv2d(
+ output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
+ ),
+ nn.LeakyReLU(),
+ nn.Conv2d(32, 1, kernel_size=1),
+ )
+
+ def set_original_shapes(self, shapes: tuple[int, int]):
+ self.original_shapes = shapes
+
+ def set_shapes(self, shapes: tuple[int, int]):
+ self.shapes = shapes
+
+ def embed_sphere_dirs(self, sphere_dirs):
+ sphere_embedding = flatten(
+ sphere_dirs, old=self.original_shapes, new=self.shapes
+ )
+ # index 0 -> Y
+ # index 1 -> Z
+ # index 2 -> X
+ r1, r2, r3 = sphere_embedding[..., 0], sphere_embedding[..., 1], sphere_embedding[..., 2]
+ polar = torch.asin(r2)
+ r3_clipped = r3.abs().clip(min=1e-5) * (2 * (r3 >= 0).int() - 1)
+ azimuth = torch.atan2(r1, r3_clipped)
+ # [polar, azimuth] is the angle field
+ sphere_embedding = torch.stack([polar, azimuth], dim=-1)
+        # expand the 2-channel angle field to the image feature dimension via a Fourier (sinusoidal basis) embedding
+ sphere_embedding = fourier_dimension_expansion(
+ sphere_embedding,
+ dim=self.hidden_dim,
+ max_freq=max(self.shapes) // 2,
+ use_cos=False,
+ )
+ return sphere_embedding
+
+ def condition(self, feat, sphere_embeddings):
+ conditioned_features = [
+ prompter(rearrange(feature, 'b h w c -> b (h w) c'), sphere_embeddings)
+ for prompter, feature in zip(self.prompt_camera, feat)
+ ]
+ return conditioned_features
+
+ def process(self, features_list, sphere_embeddings):
+ conditioned_features = self.condition(features_list, sphere_embeddings)
+ init_latents = self.to_latents(conditioned_features[0])
+ init_latents = rearrange(
+ init_latents, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
+ ).contiguous()
+ conditioned_features = [
+ rearrange(
+ x, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
+ ).contiguous()
+ for x in conditioned_features
+ ]
+ latents = init_latents
+
+ out_features = []
+ # Pyramid-like multi-layer convolutional feature extraction
+ for i, up in enumerate(self.up_sampler):
+ latents = latents + self.process_features[i](conditioned_features[i + 1])
+ latents = up(latents)
+ out_features.append(latents)
+ return out_features
+
+ def prediction_head(self, out_features):
+ depths = []
+ h_out, w_out = out_features[-1].shape[-2:]
+ for i, (layer, features) in enumerate(zip(self.pred_head, out_features)):
+ out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ if i < len(self.pred_head) - 1:
+ continue
+ depths.append(out_depth_features)
+ out_depth_features = F.interpolate(
+ out_depth_features, size=(h_out, w_out), mode='bilinear', align_corners=True
+ )
+ distance = self.to_depth_lr(out_depth_features)
+ distance = F.interpolate(
+ distance, size=self.original_shapes, mode='bilinear', align_corners=True
+ )
+ distance = self.to_depth_hr(distance)
+ return distance
+
+ def forward(
+ self,
+ features: list[torch.Tensor],
+ sphere_dirs: torch.Tensor
+ ) -> torch.Tensor:
+ sphere_embeddings = self.embed_sphere_dirs(sphere_dirs)
+ features = self.process(features, sphere_embeddings)
+ distance = self.prediction_head(features)
+ return distance
+
+
+class ViT_w_Esphere(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.dim_aligner = DimensionAligner(
+ input_dims=config['input_dims'],
+ hidden_dim=config['hidden_dim'],
+ )
+ self._vit_w_esphere = _ViT_w_Esphere(**config)
+
+ def forward(self, images, features, sphere_dirs) -> torch.Tensor:
+ _, _, H, W = images.shape
+ common_shape = features[0].shape[1:3]
+ features = self.dim_aligner(features)
+ sphere_dirs = rearrange(sphere_dirs, 'b c h w -> b (h w) c')
+
+ self._vit_w_esphere.set_shapes(common_shape)
+ self._vit_w_esphere.set_original_shapes((H, W))
+ logdistance = self._vit_w_esphere(
+ features=features,
+ sphere_dirs=sphere_dirs,
+ )
+
+ distance = torch.exp(logdistance.clip(min=-8.0, max=8.0) + 2.0)
+ distance = distance / torch.quantile(distance, 0.98)
+ return distance
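+
+
+# Numeric sketch (illustrative addition) of the output mapping above: the raw
+# prediction is treated as a log-distance, clipped to [-8, 8], shifted by +2 and
+# exponentiated, then normalized by its own 98th percentile, so the returned
+# distance map is relative rather than metric.
+if __name__ == "__main__":
+    logdistance = torch.tensor([[-10.0, -1.0, 0.0, 1.0, 10.0]])
+    distance = torch.exp(logdistance.clip(min=-8.0, max=8.0) + 2.0)
+    print(distance / torch.quantile(distance, 0.98))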
diff --git a/src/da2/utils/__init__.py b/src/da2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac61fb325760b9f30f6444fc386d07b9c23022f5
--- /dev/null
+++ b/src/da2/utils/__init__.py
@@ -0,0 +1,11 @@
+from .base import (
+ prepare_to_run
+)
+from .model import (
+ load_model
+)
+
+__all__ = [
+ 'prepare_to_run',
+ 'load_model'
+]
diff --git a/src/da2/utils/base.py b/src/da2/utils/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..148b261d5646947592c6cb00ed96eb8955c5bea6
--- /dev/null
+++ b/src/da2/utils/base.py
@@ -0,0 +1,56 @@
+import json
+import argparse
+import os
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import (
+ InitProcessGroupKwargs,
+ ProjectConfiguration,
+ set_seed
+)
+import logging
+from datetime import (
+ timedelta,
+ datetime
+)
+
+
+def load_config(config_path):
+ with open(config_path, 'r') as f:
+ config = json.load(f)
+ return config
+
+def arg_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config_path', type=str, required=True)
+ args = parser.parse_args()
+ return args
+
+def prepare_to_run():
+ args = arg_parser()
+ logging.basicConfig(
+ format='%(asctime)s --> %(message)s',
+ datefmt='%m/%d %H:%M:%S',
+ level=logging.INFO,
+ )
+ config = load_config(args.config_path)
+ kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=config['accelerator']['timeout']))
+ version = os.path.basename(args.config_path)[:-5]
+ output_dir = f'output/{version}_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
+ if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)
+ accu_steps = config['accelerator']['accumulation_nsteps']
+ accelerator = Accelerator(
+ gradient_accumulation_steps=accu_steps,
+ mixed_precision=config['accelerator']['mixed_precision'],
+ log_with=config['accelerator']['report_to'],
+ project_config=ProjectConfiguration(project_dir=output_dir),
+ kwargs_handlers=[kwargs]
+ )
+ logger = get_logger(__name__, log_level='INFO')
+ config['env']['logger'] = logger
+ set_seed(config['env']['seed'])
+ if config['env']['verbose']:
+ logger.info(f'Version: {version} (from {args.config_path})')
+ logger.info(f'Output dir: {output_dir}')
+ logger.info(f'Using {accelerator.num_processes} GPU' + ('s' if accelerator.num_processes > 1 else ''))
+ return config, accelerator, output_dir
diff --git a/src/da2/utils/d2pc.py b/src/da2/utils/d2pc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e67c8f4c4387cfde4aac4a791428e4ee27df3742
--- /dev/null
+++ b/src/da2/utils/d2pc.py
@@ -0,0 +1,76 @@
+import os
+import numpy as np
+import utils3d
+from plyfile import PlyData, PlyElement
+from PIL import Image
+
+
+def sphere_uv2dirs(uv: np.ndarray):
+ theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi
+ directions = np.stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1)
+ return directions
+
+def save_3d_points(points: np.ndarray, colors: np.ndarray, mask: np.ndarray, save_path: str):
+ points = points.reshape(-1, 3)
+ colors = colors.reshape(-1, 3)
+ mask = mask.reshape(-1, 1)
+
+ vertex_data = np.empty(mask.sum(), dtype=[
+ ('x', 'f4'),
+ ('y', 'f4'),
+ ('z', 'f4'),
+ ('red', 'u1'),
+ ('green', 'u1'),
+ ('blue', 'u1'),
+ ])
+    valid = mask.reshape(-1).astype(bool)
+    vertex_data['x'] = points[valid, 0]
+    vertex_data['y'] = points[valid, 1]
+    vertex_data['z'] = points[valid, 2]
+    vertex_data['red'] = colors[valid, 0]
+    vertex_data['green'] = colors[valid, 1]
+    vertex_data['blue'] = colors[valid, 2]
+
+ vertex_element = PlyElement.describe(vertex_data, 'vertex', comments=['vertices with color'])
+ save_dir = os.path.dirname(save_path)
+ if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True)
+ PlyData([vertex_element], text=True).write(save_path)
+
+def colorize_normal(normal: np.ndarray, normal_mask: np.ndarray):
+ normal_rgb = (((normal + 1) * 0.5) * 255).astype(np.uint8)
+ normal_mask = np.repeat(np.expand_dims(normal_mask, axis=-1), 3, axis=-1)
+ normal_mask = normal_mask.astype(np.uint8)
+ normal_rgb = normal_rgb * normal_mask
+ return normal_rgb
+
+def normal_normalize(normal: np.ndarray):
+ normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
+ normal_norm[normal_norm < 1e-6] = 1e-6
+ return normal / normal_norm
+
+def distance2pointcloud(
+ distance: np.ndarray,
+ image: np.ndarray,
+ mask: np.ndarray,
+ save_path: str = None,
+ return_normal: bool = False,
+ save_distance: bool = False
+):
+ if distance.ndim >= 3: distance = distance.squeeze()
+ if save_distance:
+ save_path_dis = save_path.replace('3dpc', 'depth').replace('.ply', '.npy')
+ save_dir_dis = os.path.dirname(save_path_dis)
+ if not os.path.exists(save_dir_dis): os.makedirs(save_dir_dis, exist_ok=True)
+ np.save(save_path_dis, distance)
+ height, width = distance.shape[:2]
+ points = distance[:, :, None] * sphere_uv2dirs(utils3d.numpy.image_uv(width=width, height=height))
+ save_3d_points(points, image, mask, save_path)
+ if return_normal:
+ normal, normal_mask = utils3d.numpy.points_to_normals(points, mask)
+ normal = normal * np.array([-1, -1, 1])
+ normal = normal_normalize(normal)
+ normal_1 = normal[..., 0]
+ normal_2 = normal[..., 1]
+ normal_3 = normal[..., 2]
+ normal = np.stack([normal_1, normal_3, normal_2], axis=-1)
+ normal_img = colorize_normal(normal, normal_mask)
+ return Image.fromarray(normal_img)
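+
+
+# Small sketch (illustrative addition): `distance2pointcloud` back-projects each
+# pixel along the spherical direction of its (u, v) coordinate, so a constant
+# distance map yields points on a sphere of that radius.
+if __name__ == "__main__":
+    uv = utils3d.numpy.image_uv(width=8, height=4)             # (4, 8, 2), values in [0, 1]
+    dirs = sphere_uv2dirs(uv)                                   # (4, 8, 3) unit vectors
+    print(np.abs(np.linalg.norm(dirs, axis=-1) - 1.0).max())    # ~0.0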
diff --git a/src/da2/utils/io.py b/src/da2/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccaf2dc0034db8ae04d02f6c6a059b98827d1d4c
--- /dev/null
+++ b/src/da2/utils/io.py
@@ -0,0 +1,63 @@
+import os
+import cv2
+import torch
+import numpy as np
+from glob import glob
+from PIL import Image
+
+
+def torch_transform(image):
+ image = image / 255.0
+ image = np.transpose(image, (2, 0, 1))
+ return image
+
+def read_cv2_image(image_path):
+ image = cv2.imread(image_path)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ return image
+
+def read_mask(mask_path, shape):
+ if not os.path.exists(mask_path):
+ return np.ones(shape[1:]) > 0
+ mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+ mask = mask > 0
+ return mask
+
+def tensorize(array, model_dtype, device):
+ array = torch.from_numpy(array).to(device).to(model_dtype).unsqueeze(dim=0)
+ return array
+
+def load_infer_data(config, device):
+ image_dir = config['inference']['images']
+ mask_dir = config['inference']['masks']
+
+ image_paths = glob(os.path.join(image_dir, '*.png'))
+ image_paths = sorted(image_paths)
+ filenames = [os.path.basename(image_path)[:-4] for image_path in image_paths]
+ cv2_images = [read_cv2_image(image_path)
+ for image_path in image_paths]
+ PIL_images = [Image.fromarray(cv2_image) for cv2_image in cv2_images]
+ images = [torch_transform(cv2_image) for cv2_image in cv2_images]
+
+ mask_paths = [image_path.replace(image_dir, mask_dir)
+ for image_path in image_paths]
+ masks = [read_mask(mask_path, images[i].shape)
+ for (i, mask_path) in enumerate(mask_paths)]
+
+ model_dtype = config['spherevit']['dtype']
+ images = [tensorize(image, model_dtype, device) for image in images]
+
+ infer_data = {
+ 'images': {
+ 'PIL': PIL_images,
+ 'cv2': cv2_images,
+ 'torch': images
+ },
+ 'masks': masks,
+ 'filenames': filenames,
+ 'size': len(images)
+ }
+ if config['env']['verbose']:
+ s = 's' if len(images) > 1 else ''
+ config['env']['logger'].info(f'Loaded {len(images)} image{s} in {model_dtype}')
+ return infer_data
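+
+
+# Minimal sketch (illustrative addition, using a synthetic array instead of a real
+# image file): images are scaled to [0, 1], moved to channels-first layout, and
+# batched as a single-sample tensor in the model's dtype.
+if __name__ == "__main__":
+    fake_rgb = np.zeros((4, 6, 3), dtype=np.uint8)   # stand-in for read_cv2_image(...)
+    image = tensorize(torch_transform(fake_rgb), torch.float32, 'cpu')
+    print(image.shape, image.dtype)                  # torch.Size([1, 3, 4, 6]) torch.float32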
diff --git a/src/da2/utils/model.py b/src/da2/utils/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0088ab1e0a1dd995767ea8c107a40b6fb1c2e836
--- /dev/null
+++ b/src/da2/utils/model.py
@@ -0,0 +1,15 @@
+import torch
+from ..model.spherevit import SphereViT
+
+
+def load_model(config, accelerator):
+ model = SphereViT.from_pretrained('haodongli/DA-2', config=config)
+ model = model.to(accelerator.device)
+ torch.cuda.empty_cache()
+ model = accelerator.prepare(model)
+ if accelerator.num_processes > 1:
+ model = model.module
+ if config['env']['verbose']:
+ config['env']['logger'].info(f'Model\'s dtype: {next(model.parameters()).dtype}.')
+ config['spherevit']['dtype'] = next(model.parameters()).dtype
+ return model
diff --git a/src/da2/utils/vis.py b/src/da2/utils/vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cbeab57788963cd0b7ff678979f8723d80911b9
--- /dev/null
+++ b/src/da2/utils/vis.py
@@ -0,0 +1,44 @@
+import torch
+from PIL import Image
+import numpy as np
+import matplotlib
+import cv2
+
+
+def concatenate_images(*image_lists):
+ max_width = 0
+ total_height = 0
+ row_widths = []
+ row_heights = []
+
+ for i, image_list in enumerate(image_lists):
+ width = sum(img.width for img in image_list)
+ max_width = max(max_width, width)
+ row_widths.append(width)
+ # Assuming all images in the list have the same height
+ height = image_list[0].height
+ total_height += height
+ row_heights.append(height)
+
+ new_image = Image.new('RGB', (max_width, total_height))
+ y_offset = 0
+ for i, image_list in enumerate(image_lists):
+ x_offset = 0
+ for img in image_list:
+ new_image.paste(img, (x_offset, y_offset))
+ x_offset += img.width
+ y_offset += row_heights[i]
+ return new_image
+
+def colorize_distance(distance, mask, cmap='Spectral'):
+ if distance.ndim >= 3: distance = distance.squeeze()
+ cm = matplotlib.colormaps[cmap]
+ valid_distance = distance[mask]
+ max_distance = np.quantile(valid_distance, 0.98)
+ min_distance = np.quantile(valid_distance, 0.02)
+ distance[~mask] = max_distance
+ distance = ((distance - min_distance) / (max_distance - min_distance))
+ distance = np.clip(distance, 0, 1)
+ img_colored_np = cm(distance, bytes=False)[:, :, 0:3]
+ distance_colored = (img_colored_np * 255).astype(np.uint8)
+ return Image.fromarray(distance_colored)
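+
+
+# Usage sketch (illustrative addition): distances are robustly rescaled between
+# their 2nd and 98th percentiles, invalid pixels are pushed to the far end, and
+# the result is mapped through a matplotlib colormap into an RGB PIL image.
+if __name__ == "__main__":
+    distance = np.linspace(1.0, 10.0, 64 * 64).reshape(64, 64)
+    mask = np.ones((64, 64), dtype=bool)
+    print(colorize_distance(distance, mask).size)    # (64, 64)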
diff --git a/src/pyproject.toml b/src/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..933ef0b55c7b60e158160965754924062a83eacd
--- /dev/null
+++ b/src/pyproject.toml
@@ -0,0 +1,35 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "da2"
+version = "0.1.0"
+description = "For the implementation of DA^2: Depth Anything in Any Direction"
+authors = [
+ { name="H. Li", email="hal211@ucsd.edu" }
+]
+dependencies = [
+ "torch==2.5.0",
+ "torchvision==0.20.0",
+ "torchaudio==2.5.0",
+ "xformers==0.0.28.post2",
+ "diffusers==0.32.0",
+ "tensorboard==2.18.0",
+ "utils3d @ git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900",
+ "opencv-python==4.12.0.88",
+ "gradio==5.49.0",
+ "gradio-client==1.13.3",
+ "gradio-imageslider==0.0.20",
+ "accelerate==1.1.1",
+ "omegaconf==2.3.0",
+ "tabulate==0.9.0",
+ "einops==0.8.0",
+ "timm==1.0.15",
+ "trimesh==4.5.2",
+ "transformers==4.46.3",
+ "matplotlib==3.9.2"
+]
+
+[tool.setuptools.packages.find]
+where = ["."]
\ No newline at end of file