import gradio as gr
import json
import os

import spaces
import torch
from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download

from superposed.llama.superposed_generation import SuperposedLlama
from superposed.llama.tokenizer import Tokenizer
from superposed.ngrams.ngram_models import make_models

# load_dotenv()
# print(os.getenv("HF_ACCESS_TOKEN"))
login(os.getenv("HF_ACCESS_TOKEN"))
# Download Llama-2-7B weights if they are not already present
if not os.path.exists("./weights/"):
    os.mkdir("./weights/")
    snapshot_download(repo_id="meta-llama/Llama-2-7b", local_dir="./weights/")
weight_path = "./weights/"

# Load params
param_file = "params/p15_d3_mixed.json"
with open(param_file, "r") as f:
    params = json.load(f)
alpha = params["alpha"]
temp = params["temp"]
n_drafts = params["n_drafts"]
prompt_len = params["prompt_len"]
n_token_sample = params["n_token_sample"]
i_weights = params["i_weights"]
i_length = params["i_length"]

# Load main model
model = SuperposedLlama.build(ckpt_dir=weight_path,
                              tokenizer_path=f'{weight_path}/tokenizer.model',
                              max_seq_len=100,
                              max_batch_size=32,
                              model_parallel_size=1)
tokenizer = Tokenizer(f'{weight_path}/tokenizer.model')

# Create ngram models
ngrams = make_models("ckpts-200k", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)
def decode(tokenizer, encoding):
    """
    Decode a token sequence, truncating at the first EOS token.

    Args:
        tokenizer (Any): Tokenizer
        encoding (torch.Tensor): Encoding
    Returns:
        decoding (str)
    """
    eos_locs = (encoding == tokenizer.eos_id).nonzero()
    if len(eos_locs) > 0:
        encoding = encoding[:eos_locs[0]]
    return tokenizer.decode(encoding.to(torch.int32).tolist())
def update_options(input, num_tokens):
    # Run superposed decoding on the current prompt and return the top three drafts
    tokenized_prompts = tokenizer.encode([input], True, False)
    alive_gens, _ = model.sup_generate(prompt_tokens=tokenized_prompts,
                                       smoothing="geom",
                                       max_gen_len=num_tokens,
                                       n_token_sample=n_token_sample,
                                       alpha=alpha,
                                       temp=temp,
                                       n_drafts=n_drafts,
                                       i_weights=i_weights,
                                       i_length=i_length,
                                       ngrams=ngrams,
                                       get_time=False,
                                       penalty=200)
    gens = alive_gens[0].reshape(n_drafts, -1)
    return decode(tokenizer, gens[0]), decode(tokenizer, gens[1]), decode(tokenizer, gens[2])
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Superposed Decoding
        Start typing below to see suggestions.
        """)
    slider = gr.Slider(minimum=1, maximum=10, step=1, label="Generation length", value=10)
    inp = gr.Textbox(placeholder="Type anything!", lines=3)
    option1 = gr.Button(value="Option 1")
    option2 = gr.Button(value="Option 2")
    option3 = gr.Button(value="Option 3")
    inp.change(update_options, inputs=[inp, slider], outputs=[option1, option2, option3])
    # Button updates: append a suggestion's text to the current prompt
    def option1_click(curr, txt):
        return curr + txt

    def option2_click(curr, txt):
        return curr + txt

    def option3_click(curr, txt):
        return curr + txt
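    # NOTE: the handlers above are never registered in this listing. A minimal
    # sketch of the assumed wiring is shown below: each button passes its own
    # label (the suggestion text) as `txt`, and the textbox is updated with the
    # concatenated prompt. This is an assumption about intent, not part of the
    # original file.
    option1.click(option1_click, inputs=[inp, option1], outputs=inp)
    option2.click(option2_click, inputs=[inp, option2], outputs=inp)
    option3.click(option3_click, inputs=[inp, option3], outputs=inp)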
if __name__ == "__main__":
    demo.launch()