| | import os |
| | import torch |
| | import torch.nn.functional as F |
| | from collections import OrderedDict |
| | import string |
| | from model import ChatGCLM, MAX_SEQ_LEN |
| |
|
# Locate a checkpoint in the current working directory.  Entries are sorted
# before matching so that, when several "Turing_*.pt" files exist, the same
# one is selected on every run (os.listdir returns names in arbitrary order).
MODEL_PATH = next(
    (
        name
        for name in sorted(os.listdir("."))
        if name.startswith("Turing_") and name.endswith(".pt")
    ),
    None,
)

if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train.py")
    # Raise SystemExit directly: the exit() helper is injected by the site
    # module and is not guaranteed to exist in every interpreter setup.
    raise SystemExit(1)
| |
|
# Token-id layout shared by encode() and decode(): every character of CHARS
# maps to its index in CHARS shifted up by OFFSET, so ids below OFFSET
# (EOS_ID among them) never correspond to a character.
EOS_ID = 2
OFFSET = 3
CHARS = string.printable

# Character -> token-id table, built once so encode() performs a single O(1)
# lookup per character instead of two linear scans of CHARS.
_CHAR_TO_ID = {ch: pos + OFFSET for pos, ch in enumerate(CHARS)}


def encode(text):
    """Convert *text* into a list of token ids.

    Characters that do not occur in CHARS are silently dropped.
    """
    return [_CHAR_TO_ID[ch] for ch in text if ch in _CHAR_TO_ID]


def decode(ids):
    """Convert an iterable of token ids back into text.

    Ids below OFFSET and ids past the end of CHARS have no character
    assigned; they are skipped instead of raising IndexError.
    """
    upper = OFFSET + len(CHARS)
    return "".join(CHARS[i - OFFSET] for i in ids if OFFSET <= i < upper)
| |
|
def _extract_state_dict(ckpt):
    """Return the parameter mapping held by a loaded checkpoint object.

    Handles a whole pickled module, a dict that nests the weights under
    'model_state_dict' or 'state_dict', and a bare state dict.
    """
    if isinstance(ckpt, torch.nn.Module):
        # A full module object has no .keys(); take its parameters instead.
        return ckpt.state_dict()
    if isinstance(ckpt, dict):
        for key in ('model_state_dict', 'state_dict'):
            if key in ckpt:
                return ckpt[key]
    return ckpt


def _strip_module_prefix(sd):
    """Return *sd* with a leading 'module.' removed from every key that has one.

    The input is returned unchanged when no key carries the prefix.
    """
    # NOTE(review): the prefix is presumably left by a DataParallel-style
    # wrapper used during training -- confirm against train.py.
    prefix = 'module.'
    if not any(k.startswith(prefix) for k in sd):
        return sd
    stripped = OrderedDict()
    for k, v in sd.items():
        stripped[k[len(prefix):] if k.startswith(prefix) else k] = v
    return stripped


def load_model(device):
    """Build a ChatGCLM and load the weights stored at MODEL_PATH.

    Args:
        device: Device the model is moved to; also used as ``map_location``
            when reading the checkpoint.

    Returns:
        The model switched to eval mode, or None when MODEL_PATH does not
        exist or is an empty file.
    """
    # Check the file first so no model is built when there is nothing to load.
    if not (os.path.exists(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0):
        print(f"Error: Could not load model from {MODEL_PATH}")
        return None

    vocab_size = len(CHARS) + OFFSET
    model = ChatGCLM(vocab_size).to(device)

    print(f"Loading model from: {MODEL_PATH}")
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code.  Only load checkpoints from a trusted source.
    ckpt = torch.load(MODEL_PATH, map_location=device)
    state_dict = _strip_module_prefix(_extract_state_dict(ckpt))

    # strict=False tolerates key mismatches; report them instead of failing.
    res = model.load_state_dict(state_dict, strict=False)
    missing = getattr(res, 'missing_keys', None)
    unexpected = getattr(res, 'unexpected_keys', None)
    if missing:
        print(f"Warning: missing keys when loading state_dict: {missing}")
    if unexpected:
        print(f"Warning: unexpected keys in state_dict: {unexpected}")

    model.eval()
    return model
| |
|
@torch.no_grad()
def generate(model, prompt, device, max_new_tokens=200, temperature=0.8, top_k=50):
    """Sample a continuation of *prompt* from *model*, streaming it to stdout.

    Args:
        model: Callable that takes a (1, T) tensor of token ids and returns
            logits indexed as [batch, position, vocab].
        prompt: Text to condition on; characters that encode() drops are
            not fed to the model.
        device: Device on which the input tensor is created.
        max_new_tokens: Upper bound on the number of tokens sampled.
        temperature: Divisor applied to the logits before sampling.  A value
            <= 0 selects the highest-logit token at every step (greedy
            decoding) instead of dividing by zero.
        top_k: When a positive integer, sampling is restricted to the k
            largest logits.  None or a non-positive value disables the filter.

    Returns:
        The generated text, without the prompt.  Generation stops early when
        EOS_ID is sampled.  Returns "" when the prompt encodes to no tokens.
    """
    model.eval()
    input_ids = encode(prompt)

    print(f"\n{'='*70}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*70}")
    print("GENERATED TEXT:")
    print(prompt, end="", flush=True)

    if not input_ids:
        # An empty context would make logits[:, -1, :] fail; there is nothing
        # to condition on, so report it and return an empty continuation.
        print("\n[prompt has no encodable characters; nothing generated]")
        print(f"\n{'='*70}\n")
        return ""

    x = torch.tensor([input_ids], dtype=torch.long, device=device)
    greedy = temperature <= 0

    generated_tokens = []
    for _ in range(max_new_tokens):
        # Only the most recent MAX_SEQ_LEN tokens are fed to the model.
        ctx = x[:, -MAX_SEQ_LEN:]
        next_token_logits = model(ctx)[:, -1, :]

        if greedy:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        else:
            # Division yields a fresh tensor, so the in-place mask below does
            # not touch the model's output.
            next_token_logits = next_token_logits / temperature

            if top_k is not None and top_k > 0:
                k = min(top_k, next_token_logits.size(-1))
                top_values, _indices = torch.topk(next_token_logits, k)
                # Everything below the k-th largest logit gets probability 0.
                next_token_logits[next_token_logits < top_values[:, [-1]]] = -float('Inf')

            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        idx = next_token.item()
        if idx == EOS_ID:
            break

        # Extend the trimmed window, not the full history: older tokens are
        # never read again, so the tensor stays bounded in length.
        x = torch.cat((ctx, next_token), dim=1)
        generated_tokens.append(idx)
        print(decode([idx]), end="", flush=True)

    print(f"\n{'='*70}\n")
    return decode(generated_tokens)
| |
|
if __name__ == "__main__":
    # Device preference: CUDA, then Apple MPS, then CPU.  The mps backend is
    # looked up with getattr because not every PyTorch build exposes it.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model = load_model(device)
    if model is None:
        # load_model has already printed the reason.
        raise SystemExit(1)

    test_prompts = [
        "Once upon a time",
        "The future of AI is",
        "In a world where",
    ]

    print("\n" + "="*70)
    print("ChatGCLM Text Generation Demo")
    print("="*70)

    for prompt in test_prompts:
        generate(model, prompt, device, max_new_tokens=150, temperature=0.8, top_k=50)

    print("\n" + "="*70)
    print("Interactive Mode - Enter your own prompts!")
    print("="*70)

    while True:
        try:
            user_prompt = input("\nEnter prompt (or 'exit' to quit): ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session without a traceback.
            print()
            break
        stripped_prompt = user_prompt.strip()
        # Compare after stripping so "exit " or " EXIT" also quits.
        if stripped_prompt.lower() == 'exit':
            break
        if stripped_prompt:
            generate(model, user_prompt, device, max_new_tokens=200, temperature=0.8, top_k=50)