import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer

INPUT_FILE = "chat_1turn.csv"
EMB_FILE = "chat_embeddings.pt"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS = 80
BATCH_SIZE = 16
HIDDEN_DIM = 512
MAX_LEN = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load source/target chat pairs
df = pd.read_csv(INPUT_FILE)
sources = df["source"].fillna("").tolist()
targets = df["target"].fillna("").tolist()

# Tokenize the targets with the T5 tokenizer (the decoder's output vocabulary)
tokenizer = T5Tokenizer.from_pretrained("t5-small")
target_enc = tokenizer(targets, padding=True, truncation=True,
                       return_tensors="pt", max_length=MAX_LEN)
input_ids = target_enc["input_ids"].to(device)
attention_mask = target_enc["attention_mask"].to(device)

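# EMB_FILE is assumed to already contain a dict with "source" and "target"
# sentence-embedding tensors. A minimal sketch of how it could be produced with
# the same model (an assumption; this step is not shown in the original script):
#
#     embedder = SentenceTransformer(MODEL_NAME)
#     torch.save({
#         "source": embedder.encode(sources, convert_to_tensor=True).cpu(),
#         "target": embedder.encode(targets, convert_to_tensor=True).cpu(),
#     }, EMB_FILE)
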
# Load precomputed sentence embeddings for both sides of each pair
emb_data = torch.load(EMB_FILE)
x_embeddings = emb_data["source"].to(device)
y_embeddings = emb_data["target"].to(device)

class EmbeddingDecoder(nn.Module):
    """GRU decoder that reconstructs token IDs from a single sentence embedding.

    The embedding is projected to the GRU's initial hidden state, and tokens are
    then generated step by step: emb_vec of shape (B, input_dim) is mapped to
    logits of shape (B, max_len, vocab_size).
    """

    def __init__(self, input_dim, hidden_dim, vocab_size):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, emb_vec, target_ids=None, teacher_forcing_ratio=0.5, max_len=MAX_LEN):
        hidden = self.bridge(emb_vec).unsqueeze(0)  # (1, B, hidden_dim) initial state
        B = emb_vec.size(0)
        outputs = []

        # T5 uses the pad token as the decoder start token
        inp = torch.full((B, 1), tokenizer.pad_token_id, device=emb_vec.device)

        for t in range(max_len):
            inp_emb = self.embed(inp)
            out, hidden = self.gru(inp_emb, hidden)
            logits = self.fc(out.squeeze(1))
            outputs.append(logits.unsqueeze(1))

            # Teacher forcing: sometimes feed the ground-truth token,
            # otherwise feed the model's own prediction
            if target_ids is not None and t < target_ids.size(1) and torch.rand(1).item() < teacher_forcing_ratio:
                inp = target_ids[:, t].unsqueeze(1)
            else:
                inp = torch.argmax(logits, dim=-1, keepdim=True)

        return torch.cat(outputs, dim=1)

decoder = EmbeddingDecoder(y_embeddings.shape[1], HIDDEN_DIM, tokenizer.vocab_size).to(device)
optimizer = optim.Adam(decoder.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Train the decoder to reconstruct the target token IDs from the target embeddings
print("Training decoder...")
for epoch in range(EPOCHS):
    decoder.train()
    total_loss = 0.0
    for i in range(0, len(y_embeddings), BATCH_SIZE):
        xb = y_embeddings[i:i+BATCH_SIZE]
        yb = input_ids[i:i+BATCH_SIZE]

        optimizer.zero_grad()
        logits = decoder(xb, target_ids=yb, teacher_forcing_ratio=0.7, max_len=yb.size(1))
        loss = criterion(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}")

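# Optional sanity check (a sketch, not in the original script): decode the
# first training example back from its own target embedding.
#
#     decoder.eval()
#     with torch.no_grad():
#         sample_logits = decoder(y_embeddings[:1], max_len=MAX_LEN)
#         sample_ids = torch.argmax(sample_logits, dim=-1)[0].tolist()
#         print("Reconstruction:", tokenizer.decode(sample_ids, skip_special_tokens=True))
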
# Embedder used to turn new input text into a sentence embedding at inference time
embedder = SentenceTransformer(MODEL_NAME, device=device)

def generate(text, max_len=30, use_mapper=False, mapper=None):
    with torch.no_grad():
        emb = embedder.encode([text], convert_to_tensor=True, device=device)
        if use_mapper and mapper is not None:
            emb = mapper(emb)
        logits = decoder(emb, target_ids=None, teacher_forcing_ratio=0.0, max_len=max_len)
        ids = torch.argmax(logits, dim=-1).squeeze(0).tolist()
        return tokenizer.decode(ids, skip_special_tokens=True)

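# Note: x_embeddings (source side) are loaded above but never used here, and
# generate() accepts use_mapper/mapper arguments. This points at an optional
# source-to-target embedding mapper trained separately. A minimal sketch of such
# a module (an assumption; it is not defined or trained in this script):
#
#     mapper = nn.Sequential(
#         nn.Linear(x_embeddings.shape[1], HIDDEN_DIM),
#         nn.ReLU(),
#         nn.Linear(HIDDEN_DIM, y_embeddings.shape[1]),
#     ).to(device)
#     # ...train mapper to map x_embeddings to y_embeddings (e.g. with an MSE loss),
#     # then: generate("Hi", use_mapper=True, mapper=mapper)
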
print("Hi ->", generate("Hi"))