import torch
from transformers import AutoTokenizer
import gradio as gr
from huggingface_hub import hf_hub_download
from model import LlamaForCausalLM  # Import your custom model class
# Load tokenizer and ensure a pad token exists (fall back to EOS)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "[PAD]"
# Initialize model with reduced parameters (SmolLM2-135M config)
class Config:
    pass

config = Config()
config.vocab_size = tokenizer.vocab_size  # reuse the tokenizer loaded above
config.num_layers = 30
config.hidden_size = 576
config.num_attention_heads = 8  # head_dim = 576 / 8 = 72
config.rms_norm_eps = 1.0e-05
config.max_position_embeddings = 2048
config.hidden_act = False  # unusual value; presumably interpreted by the custom model.py (stock Llama configs use "silu")
config.intermediate_size = 1536
config.rope_interleaved = False
# config.rope_scaling = None
config.rope_theta = 10000.0  # an earlier assignment of 500000.0 was dead code; the last write wins
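
# For comparison, a rough sketch of the same hyperparameters expressed with the
# stock transformers LlamaConfig (an assumption for illustration only; the custom
# model.py reads differently named attributes such as num_layers):
#
#     from transformers import LlamaConfig
#     hf_config = LlamaConfig(
#         vocab_size=tokenizer.vocab_size,
#         hidden_size=576,
#         num_hidden_layers=30,
#         num_attention_heads=8,
#         intermediate_size=1536,
#         rms_norm_eps=1e-5,
#         max_position_embeddings=2048,
#         rope_theta=10000.0,
#     )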
model = LlamaForCausalLM(config)

# Download the trained weights from the Hub and load them on CPU
device = "cpu"
model_id = "chbsaikiran/smollm2_135M_model"
checkpoint_path = hf_hub_download(repo_id=model_id, filename="model_state_dict.pt")
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
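
# Optional sanity check (a sketch): confirm the loaded weights roughly match
# the advertised 135M-parameter scale before serving the demo.
# num_params = sum(p.numel() for p in model.parameters())
# print(f"Loaded {num_params / 1e6:.0f}M parameters")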
def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        for _ in range(max_length):  # max_length counts newly generated tokens
            outputs = model(input_ids)
            next_token_logits = outputs[:, -1, :] / temperature
            # Apply top-k sampling: keep only the k most likely tokens
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            probs = torch.softmax(top_k_logits, dim=-1)
            # Sample from the renormalized distribution, then map back to vocab ids
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices[0, next_token_idx[0]]
            if next_token.item() == tokenizer.eos_token_id:
                break
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
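
# Example usage (illustrative prompt; any text works):
# print(generate_text("ROMEO:", max_length=50, temperature=0.8, top_k=40))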
# Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Input Prompt", lines=3),
        gr.Slider(20, 200, value=50, label="Max Length"),
        gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),
        gr.Slider(10, 100, value=50, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),
    title="SmolLM2 Demo",
    description="A 135M parameter language model trained on Shakespeare's text",
)
if __name__ == "__main__":
    demo.launch()
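
# Note: on Hugging Face Spaces, launch() serves the app directly; when running
# locally, demo.launch(share=True) also creates a temporary public link.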