File size: 4,050 Bytes
fcbe832 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
from pathlib import Path
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    """Build a CPU ``LLaMA`` generator from the ``*.pth`` files in ``ckpt_dir``.

    The checkpoint files are visited in sorted order and copied into one
    ``Transformer``. For each model parameter, the second-to-last component
    of its dotted name selects how checkpoint ``i`` is merged:

    * ``0``    -- checkpoint ``i`` fills the ``i``-th block of rows,
    * ``-1``   -- checkpoint ``i`` fills the ``i``-th block of columns,
    * ``None`` -- the tensor is taken whole from the first checkpoint only.

    Args:
        ckpt_dir: Directory containing ``params.json`` and the ``*.pth`` files.
        tokenizer_path: Passed to ``Tokenizer`` as ``model_path``.
        max_seq_len: Forwarded to ``ModelArgs``.
        max_batch_size: Forwarded to ``ModelArgs``.

    Returns:
        A ``LLaMA`` wrapping the populated model and the tokenizer.

    Raises:
        FileNotFoundError: If ``ckpt_dir`` contains no ``*.pth`` file, or
            ``params.json`` cannot be opened.
        KeyError: If a parameter's module name is missing from the merge
            table, or a checkpoint lacks one of the model's parameters.
    """
    print("Creating model...")
    start_time = time.time()

    ckpt_root = Path(ckpt_dir)
    checkpoints = sorted(ckpt_root.glob("*.pth"))
    if not checkpoints:
        # Without this guard the loop below is skipped and a model with
        # untouched (random) weights would be returned as if it had loaded.
        raise FileNotFoundError(
            f"No *.pth checkpoint files found in {ckpt_root}"
        )

    with open(ckpt_root / "params.json", "r") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    model = Transformer(model_args)

    # Original copyright by tloen
    # https://github.com/tloen/llama-int8/blob/main/example.py
    key_to_dim = {
        "w1": 0,
        "w2": -1,
        "w3": 0,
        "wo": -1,
        "wq": 0,
        "wk": 0,
        "wv": 0,
        "output": 0,
        "tok_embeddings": -1,
        "ffn_norm": None,
        "attention_norm": None,
        "norm": None,
        "rope": None,
    }

    for i, ckpt in enumerate(checkpoints):
        print(f"Loading checkpoint {i}")
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(ckpt, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            short_name = parameter_name.split(".")[-2]
            shard_dim = key_to_dim[short_name]
            # pop() drops the dict's reference right away so each tensor can
            # be freed as soon as it has been copied into the model.
            shard = checkpoint.pop(parameter_name)
            if shard_dim is None:
                if i == 0:
                    parameter.data = shard
            elif shard_dim == 0:
                size = shard.size(0)
                parameter.data[size * i: size * (i + 1), :] = shard
            elif shard_dim == -1:
                size = shard.size(-1)
                parameter.data[:, size * i: size * (i + 1)] = shard
        del checkpoint

    model.to("cpu")

    generator = LLaMA(model, tokenizer)
    print(f"Loaded model in {time.time() - start_time:.2f} seconds")
    return generator
def main(
    ckpt_dir: str = './model',
    tokenizer_path: str = './tokenizer/tokenizer.model',
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_seq_len: int = 512,  # up to 2048
    max_batch_size: int = 32,
    max_gen_len: int = 256,
):
    """Load the model, run the built-in sample prompts and print the results.

    Args:
        ckpt_dir: Checkpoint directory handed to ``load``.
        tokenizer_path: Tokenizer model path handed to ``load``.
        temperature: Forwarded to ``generator.generate``.
        top_p: Forwarded to ``generator.generate``.
        max_seq_len: Forwarded to ``load``.
        max_batch_size: Forwarded to ``load``.
        max_gen_len: Forwarded to ``generator.generate``; the default of 256
            matches the value that used to be hard-coded here.
    """
    # torch.manual_seed(1)
    # torch.set_default_dtype(torch.bfloat16)
    generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)

    prompts = [
        ##### For these prompts, the expected answer is the natural continuation of the prompt #####
        "I believe the meaning of life is",
        # "Simply put, the theory of relativity states that ",
        # "Building a website can be done in 10 simple steps:\n",
        ##### Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api #####
        # """Tweet: "I hate it when my phone battery dies."
        # Sentiment: Negative
        # ###
        # Tweet: "My day has been 👍"
        # Sentiment: Positive
        # ###
        # Tweet: "This is the link to the article"
        # Sentiment: Neutral
        # ###
        # Tweet: "This new music video was incredibile"
        # Sentiment:""",
        # """Translate English to French:
        # sea otter => loutre de mer
        # peppermint => menthe poivrée
        # plush girafe => girafe peluche
        # cheese =>""",
    ]

    results = generator.generate(
        prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
    )

    for result in results:
        print(result)
        print("\n==================================\n")
# Script entry point: Fire exposes main()'s keyword arguments as
# command-line flags.
if __name__ == "__main__":
    fire.Fire(main)
|