Spaces:
Build error
Build error
| import os | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from threading import Thread | |
| import torch | |
# Choose the compute device up front: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of visible CUDA devices (0 when CUDA is unavailable).
n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

# Load the distilgpt2 tokenizer and causal LM, then move the model's
# weights onto the selected device.
tok = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
model.to(device)
def generate(text="", max_new_tokens=128):
    """Stream a sampled continuation of ``text`` from the module-level model.

    Generation runs on a background thread via ``model.generate``; decoded
    chunks arrive through a ``TextIteratorStreamer`` and are accumulated
    here. This function is a generator so the UI can show partial output.

    Args:
        text: Prompt string. ``None`` or an empty string is replaced by a
            single space.
        max_new_tokens: Upper bound on the number of newly generated
            tokens; coerced to ``int`` before being passed on.

    Yields:
        The text accumulated so far. The streamer is created without
        ``skip_prompt``, so the prompt is presumably echoed at the start
        (confirm against the installed transformers version). If the
        tokenizer's end-of-sequence marker appears, the text is cut just
        before it, yielded once more, and streaming stops. The marker
        itself is never yielded.

    Returns:
        The final accumulated text (as the generator's return value) when
        the streamer is exhausted without an end-of-sequence marker.
    """
    # timeout bounds how long the consumer loop below waits on the
    # streamer's queue for the next chunk.
    streamer = TextIteratorStreamer(tok, timeout=10.)

    # Fall back to a single space so the model always has a prompt;
    # `not text` also covers a None coming from the UI/API.
    if not text:
        text = " "

    inputs = tok([text], return_tensors="pt").to(device)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        repetition_penalty=2.0,
        do_sample=True,
        top_k=40,
        top_p=0.97,
        # Coerce defensively in case the slider hands over a float.
        max_new_tokens=int(max_new_tokens),
        pad_token_id=model.config.eos_token_id,
        early_stopping=True,
        no_repeat_ngram_size=4,
    )

    # model.generate blocks until done, so it runs on a worker thread while
    # this generator consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    eos_token = tok.eos_token  # may be None for tokenizers without an EOS token
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        # Check for the EOS marker BEFORE yielding so the raw marker is
        # never shown to the user, not even transiently.
        if eos_token and eos_token in generated_text:
            generated_text = generated_text[: generated_text.find(eos_token)]
            streamer.end()
            yield generated_text
            return
        yield generated_text
    return generated_text
# Gradio UI: a prompt box and a token-budget slider feed `generate`, whose
# streamed output is shown in a single text box. Flagging is disabled.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(lines=5, label="Input Text"),
        gr.Slider(
            value=128,
            minimum=5,
            maximum=256,
            step=1,
            label="Maximum number of new tokens",
        ),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="TextIteratorStreamer + Gradio demo",
    allow_flagging="never",
)

# Turn on the request queue, then start the app.
demo.queue()
demo.launch()