import torch
from huggingface_hub import hf_hub_download

from .bisenet import BiSeNet
from .parsenet import ParseNet

REPO_ID = "leonelhs/facexlib"


def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
    """Build a face-parsing network and load its pretrained weights.

    The checkpoint file is fetched from the Hugging Face Hub repository
    ``REPO_ID`` with ``hf_hub_download``.

    Args:
        model_name: Architecture to build, either ``'bisenet'`` or
            ``'parsenet'``.
        half: Accepted for API compatibility but not used by this function;
            no half-precision conversion is performed here.
        device: Device the model is moved to before it is returned.

    Returns:
        The network with its weights loaded (strict), switched to eval mode
        and placed on ``device``.

    Raises:
        NotImplementedError: If ``model_name`` is not one of the supported
            names.
    """
    if model_name == 'bisenet':
        net = BiSeNet(num_class=19)
        checkpoint_file = 'parsing_bisenet.pth'
    elif model_name == 'parsenet':
        net = ParseNet(in_size=512, out_size=512, parsing_ch=19)
        checkpoint_file = 'parsing_parsenet.pth'
    else:
        raise NotImplementedError(f'{model_name} is not implemented.')

    weights_path = hf_hub_download(repo_id=REPO_ID, filename=checkpoint_file)

    # NOTE(review): torch.load unpickles the downloaded file. Unless the
    # installed torch defaults to weights_only=True, this trusts the Hub
    # repository's contents; consider weights_only=True once the minimum
    # supported torch version allows it.
    # The callable map_location returns each storage unchanged, i.e. keeps it
    # on the CPU; the model is moved to ``device`` only after loading.
    state_dict = torch.load(weights_path, map_location=lambda storage, loc: storage)
    net.load_state_dict(state_dict, strict=True)
    net.eval()
    return net.to(device)