Spaces:
Running
Running
| import torch | |
| from transformers import T5Tokenizer, GPT2LMHeadModel | |
| from flask import Flask, request, jsonify | |
# Prefer the GPU whenever CUDA is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer and the causal language model come from the same checkpoint.
_LYRIC_CHECKPOINT = "skytnt/gpt2-japanese-lyric-medium"
tokenizer = T5Tokenizer.from_pretrained(_LYRIC_CHECKPOINT)
# nn.Module.to() moves the weights and returns the same module object.
model = GPT2LMHeadModel.from_pretrained(_LYRIC_CHECKPOINT).to(device)
def _encode_prompt(title: str, prompt_text: str):
    """Encode a title/lyric prompt as a (1, seq_len) id tensor, or None if both are empty.

    The prompt is laid out as ``<s>{title}[CLS]{prompt_text}``. Real newlines are
    rewritten as a literal backslash followed by ``n`` and a space. That is the
    line-separator form which ``_split_title_and_lyric`` turns back into newlines.
    """
    if not title and not prompt_text:
        # Nothing to condition on: generate() is called with input_ids=None
        # and starts from the bos_token_id passed to it.
        return None
    text = ("<s>" + title + "[CLS]" + prompt_text).replace("\n", "\\n ")
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    # view(1, -1): a batch holding exactly one prompt.
    return torch.LongTensor(token_ids).view(1, -1).to(device)


def _split_title_and_lyric(generated_text: str):
    """Turn decoded model text into a ``(title, lyric)`` pair of display strings.

    - Splits on the literal backslash-n separator and strips each line.
    - Turns remaining ASCII spaces into U+3000 ideographic spaces.
    - Drops ``<s>`` and turns ``</s>`` into a trailing ``---end---`` marker.
    - Splits once on ``[CLS]``. Without a ``[CLS]`` the title is "" and
      everything is the lyric.
    """
    lines = [line.strip() for line in generated_text.split("\\n")]
    text = (
        "\n".join(lines)
        # NOTE(review): full-width spaces presumably suit Japanese display — confirm.
        .replace(" ", "\u3000")
        .replace("<s>", "")
        .replace("</s>", "\n\n---end---")
    )
    head, sep, tail = text.partition("[CLS]")
    if not sep:
        return "", head.strip()
    return head.strip(), tail.strip()


def gen_lyric(
    title: str,
    prompt_text: str,
    *,
    max_length: int = 512,
    top_p: float = 0.95,
    top_k: int = 40,
    temperature: float = 1.0,
):
    """Sample one title/lyric pair from the lyric model.

    Args:
        title: Title to condition on. "" or None lets the model invent one.
        prompt_text: Opening lyric text to continue. "" or None means no
            lyric prefix. When both are empty, generation is unconditioned.
        max_length: Maximum total sequence length in tokens, including the
            prompt (passed straight to ``model.generate``).
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.
        temperature: Sampling temperature.

    Returns:
        A ``(title, lyric)`` tuple of strings. The lyric ends with a
        ``---end---`` marker when the model emitted ``</s>``.
    """
    if title is None:
        title = ""
    if prompt_text is None:
        prompt_text = ""

    prompt_tensor = _encode_prompt(title, prompt_text)
    output_sequences = model.generate(
        input_ids=prompt_tensor,
        max_length=max_length,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        do_sample=True,
        # early_stopping is intentionally not passed. It only affects
        # beam-based generation, and with num_beams=1 newer transformers
        # versions warn about it on every call.
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=1,
    )

    # Decode at the token level. The markers <s>, </s> and [CLS] are expected
    # to survive decoding; _split_title_and_lyric handles them.
    generated_sequence = output_sequences.tolist()[0]
    generated_tokens = tokenizer.convert_ids_to_tokens(generated_sequence)
    generated_text = tokenizer.convert_tokens_to_string(generated_tokens)
    return _split_title_and_lyric(generated_text)
# Serve the prebuilt frontend from frontend/dist. static_url_path="" exposes
# those files at the site root rather than under a URL prefix.
app = Flask(
    __name__,
    static_folder="frontend/dist",
    static_url_path="",
)


def index_page():
    """Return the frontend entry page, ``index.html`` from the static folder.

    NOTE(review): no ``@app.route`` decorator is visible for this view;
    presumably it should be served at "/" — confirm how it is registered.
    """
    entry_page = "index.html"
    return app.send_static_file(entry_page)
def generate():
    """Handle a lyric-generation request and reply with a JSON object.

    Expects a JSON object body with string fields ``title`` and ``text``
    (either may be empty). The reply's ``state`` field always mirrors the
    HTTP status:

    * 200 -- ``{"state": 200, "title": ..., "lyric": ...}``
    * 400 -- malformed request; ``msg`` says what is wrong
    * 405 -- any method other than POST
    * 500 -- generation itself failed; ``msg`` carries the exception text

    NOTE(review): no ``@app.route`` decorator is visible for this view;
    confirm how (and at which URL) it is registered.
    """

    def reply(state, headers=None, **fields):
        # One place that keeps the JSON "state" and the HTTP status in sync.
        body = {"state": state, **fields}
        if headers:
            return jsonify(body), state, headers
        return jsonify(body), state

    if request.method != "POST":
        # A view must not return None (Flask raises TypeError), so answer explicitly.
        return reply(405, headers={"Allow": "POST"},
                     msg=f"method {request.method} not allowed")

    # silent=True: a missing or non-JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return reply(400, msg="request body must be a JSON object")

    missing = [key for key in ("title", "text") if key not in data]
    if missing:
        return reply(400, msg=f"missing field(s): {', '.join(missing)}")

    title, text = data["title"], data["text"]
    if not isinstance(title, str) or not isinstance(text, str):
        return reply(400, msg="'title' and 'text' must be strings")

    try:
        lyric_title, lyric = gen_lyric(title, text)
    except Exception as e:  # request boundary: log it and report a server error
        app.logger.exception("lyric generation failed")
        return reply(500, msg=f"{e}")

    return reply(200, title=lyric_title, lyric=lyric)
if __name__ == '__main__':
    # Listen on all interfaces at port 7860, with the debugger and the
    # auto-reloader both switched off.
    app.run(
        port=7860,
        host="0.0.0.0",
        use_reloader=False,
        debug=False,
    )