Spaces: Running on Zero
| import os, torch | |
| from typing import List, Tuple, Optional, Union, Dict | |
| from .ebc import _ebc, EBC | |
| from .clip_ebc import _clip_ebc, CLIP_EBC | |
class _ModelConfigError(ValueError, AssertionError):
    """Raised by ``get_model`` when a required configuration value is missing.

    Derives from ``ValueError`` because this is argument validation, and from
    ``AssertionError`` so callers written against the earlier ``assert``-based
    checks still catch it.
    """


def _require(condition: bool, message: str) -> None:
    """Raise ``_ModelConfigError(message)`` when ``condition`` is false.

    Unlike a bare ``assert``, this check is not stripped by ``python -O``.
    """
    if not condition:
        raise _ModelConfigError(message)


def get_model(
    model_info_path: str,
    model_name: Optional[str] = None,
    block_size: Optional[int] = None,
    bins: Optional[List[Tuple[float, float]]] = None,
    bin_centers: Optional[List[float]] = None,
    zero_inflated: Optional[bool] = True,
    # parameters for CLIP_EBC
    clip_weight_name: Optional[str] = None,
    num_vpt: Optional[int] = None,
    vpt_drop: Optional[float] = None,
    input_size: Optional[int] = None,
    norm: str = "none",
    act: str = "none",
    text_prompts: Optional[List[str]] = None
) -> Union[EBC, CLIP_EBC]:
    """Build an EBC or CLIP_EBC model, backed by a cached model-info file.

    If ``model_info_path`` exists, it is loaded with ``torch.load`` and every
    configuration value is taken from its ``"config"`` entry. The matching
    arguments passed to this function are ignored. Its ``"weights"`` entry,
    when not ``None``, is loaded into the model with ``load_state_dict``.

    Otherwise the model is built from the arguments, and
    ``{"config": <resolved config>, "weights": None}`` is written to
    ``model_info_path`` so a later call can rebuild the same architecture.

    Model names starting with ``"CLIP_"`` or ``"CLIP-"`` are built with
    ``_clip_ebc`` (with that prefix stripped); every other name is built
    with ``_ebc``.

    Args:
        model_info_path: Path of the model-info file to load from, or to
            create when it does not exist.
        model_name: Model name; required when ``model_info_path`` does not
            exist.
        block_size: Required when ``model_info_path`` does not exist.
        bins: Required when ``model_info_path`` does not exist. Presumably
            (low, high) bin intervals -- TODO confirm against ``_ebc``.
        bin_centers: Required when ``model_info_path`` does not exist.
        zero_inflated: Forwarded to the model builder.
        clip_weight_name: Required for CLIP models; forwarded to
            ``_clip_ebc`` as ``weight_name``.
        num_vpt: Required when ``model_name`` contains ``"ViT"``.
        vpt_drop: Required when ``model_name`` contains ``"ViT"``.
        input_size: Forwarded to the model builder.
        norm: Forwarded to the model builder.
        act: Forwarded to the model builder.
        text_prompts: Forwarded to ``_clip_ebc``; not used for non-CLIP
            models.

    Returns:
        The constructed model, with ``model.config`` set to the resolved
        configuration dict.

    Raises:
        _ModelConfigError: A subclass of both ``ValueError`` and
            ``AssertionError``, raised when a required value is missing.
    """
    if os.path.exists(model_info_path):
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only point this at
        # model-info files you created or otherwise trust.
        model_info = torch.load(model_info_path, map_location="cpu", weights_only=False)
        config = model_info["config"]
        model_name = config["model_name"]
        block_size = config["block_size"]
        bins = config["bins"]
        bin_centers = config["bin_centers"]
        zero_inflated = config["zero_inflated"]
        # Optional keys: older info files may not contain them.
        clip_weight_name = config.get("clip_weight_name", None)
        num_vpt = config.get("num_vpt", None)
        vpt_drop = config.get("vpt_drop", None)
        input_size = config.get("input_size", None)
        text_prompts = config.get("text_prompts", None)
        norm = config.get("norm", "none")
        act = config.get("act", "none")
        weights = model_info["weights"]
    else:
        _require(model_name is not None, "model_name should be provided if model_info_path is not provided")
        _require(block_size is not None, "block_size should be provided")
        _require(bins is not None, "bins should be provided")
        _require(bin_centers is not None, "bin_centers should be provided")
        weights = None

    if "ViT" in model_name:
        _require(num_vpt is not None, f"num_vpt should be provided for ViT models, got {num_vpt}")
        _require(vpt_drop is not None, f"vpt_drop should be provided for ViT models, got {vpt_drop}")

    if model_name.startswith(("CLIP_", "CLIP-")):
        _require(clip_weight_name is not None, f"clip_weight_name should be provided for CLIP models, got {clip_weight_name}")
        model = _clip_ebc(
            model_name=model_name[5:],  # both accepted prefixes are 5 characters long
            weight_name=clip_weight_name,
            block_size=block_size,
            bins=bins,
            bin_centers=bin_centers,
            zero_inflated=zero_inflated,
            num_vpt=num_vpt,
            vpt_drop=vpt_drop,
            input_size=input_size,
            text_prompts=text_prompts,
            norm=norm,
            act=act,
        )
        model_config = {
            "model_name": model_name,
            "block_size": block_size,
            "bins": bins,
            "bin_centers": bin_centers,
            "zero_inflated": zero_inflated,
            "clip_weight_name": clip_weight_name,
            "num_vpt": num_vpt,
            "vpt_drop": vpt_drop,
            "input_size": input_size,
            # Record the prompts the built model reports, not the raw
            # argument (which may have been None).
            "text_prompts": model.text_prompts,
            "norm": norm,
            "act": act,
        }
    else:
        model = _ebc(
            model_name=model_name,
            block_size=block_size,
            bins=bins,
            bin_centers=bin_centers,
            zero_inflated=zero_inflated,
            num_vpt=num_vpt,
            vpt_drop=vpt_drop,
            input_size=input_size,
            norm=norm,
            act=act,
        )
        model_config = {
            "model_name": model_name,
            "block_size": block_size,
            "bins": bins,
            "bin_centers": bin_centers,
            "zero_inflated": zero_inflated,
            "num_vpt": num_vpt,
            "vpt_drop": vpt_drop,
            "input_size": input_size,
            "norm": norm,
            "act": act,
        }

    model.config = model_config
    model_info = {"config": model_config, "weights": weights}

    if weights is not None:
        model.load_state_dict(weights)

    # Persist the resolved config (without weights) on first use so later
    # calls can rebuild the same architecture from model_info_path alone.
    if not os.path.exists(model_info_path):
        torch.save(model_info, model_info_path)

    return model


# Public API of this module.
__all__ = ["get_model"]