| | """ |
| | Example usage of CosmicFish model (using safetensors) |
| | """ |
| | import torch |
| | from transformers import GPT2Tokenizer |
| | from modeling_cosmicfish import CosmicFish, CosmicConfig |
| | from safetensors.torch import load_file |
| | import json |
| |
|
def load_cosmicfish(model_dir):
    """Load a CosmicFish checkpoint (safetensors weights) and its tokenizer.

    Args:
        model_dir: Directory containing ``config.json`` and
            ``model.safetensors``.

    Returns:
        Tuple of ``(model, tokenizer)`` — the model in eval mode and a
        GPT-2 tokenizer.
    """
    # Architecture hyperparameters are stored next to the weights.
    with open(f"{model_dir}/config.json", "r") as f:
        cfg = json.load(f)

    model_config = CosmicConfig(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["block_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embd=cfg["n_embd"],
        bias=cfg["bias"],
        dropout=0.0,  # inference-time load: dropout disabled
        use_rotary=cfg["use_rotary"],
        use_swiglu=cfg["use_swiglu"],
        use_gqa=cfg["use_gqa"],
        n_query_groups=cfg["n_query_groups"],
        use_qk_norm=cfg["use_qk_norm"],
    )

    model = CosmicFish(model_config)

    weights = load_file(f"{model_dir}/model.safetensors")

    # Safetensors stores shared tensors only once; when the output head was
    # tied to the token embedding at save time, restore the tie here so that
    # load_state_dict sees every expected key.
    if 'lm_head.weight' not in weights and 'transformer.wte.weight' in weights:
        print("Weight sharing detected: tying lm_head.weight to transformer.wte.weight")
        weights['lm_head.weight'] = weights['transformer.wte.weight']

    model.load_state_dict(weights)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    return model, tokenizer
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|