File size: 546 Bytes
d82e7f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from ..model.spherevit import SphereViT


def load_model(config, accelerator):
    """Load the pretrained SphereViT model and prepare it with ``accelerator``.

    The model is created via ``SphereViT.from_pretrained('haodongli/DA-2', config=config)``,
    moved to ``accelerator.device`` and passed through ``accelerator.prepare``.
    In multi-process runs, the distributed wrapper added by ``prepare`` is
    removed again, so callers get the underlying ``SphereViT`` module.

    Args:
        config: Nested config mapping. This function reads
            ``config['env']['verbose']`` and ``config['env']['logger']`` and
            writes ``config['spherevit']['dtype']``.
            NOTE(review): assumed to be a plain dict-like object; confirm
            against the caller.
        accelerator: The Accelerate ``Accelerator`` used to place and
            prepare the model.

    Returns:
        The prepared (and, when ``num_processes > 1``, unwrapped) model.

    Side effects:
        Mutates ``config['spherevit']['dtype']`` in place, setting it to the
        dtype of the model's first parameter.
    """
    model = SphereViT.from_pretrained('haodongli/DA-2', config=config)
    model = model.to(accelerator.device)
    # Presumably releases cached allocator blocks left over from loading the
    # checkpoint before `prepare` runs; a no-op if CUDA was never initialized.
    torch.cuda.empty_cache()
    model = accelerator.prepare(model)
    if accelerator.num_processes > 1:
        # Use Accelerate's own unwrapping instead of `model.module`: not every
        # multi-process backend wraps the model in an object exposing
        # `.module`, and `unwrap_model` handles each wrapper type correctly.
        model = accelerator.unwrap_model(model)
    # Read the parameter dtype once; it is both logged and stored in config.
    model_dtype = next(model.parameters()).dtype
    if config['env']['verbose']:
        config['env']['logger'].info(f'Model\'s dtype: {model_dtype}.')
    config['spherevit']['dtype'] = model_dtype
    return model