rahul7star committed
Commit a4c1e96 · verified · 1 Parent(s): a8678a6

Update app_flash.py

Files changed (1):
  1. app_flash.py +25 -25
app_flash.py CHANGED
@@ -4,6 +4,7 @@ import torch.optim as optim
 from flashpack import FlashPackMixin
 from datasets import load_dataset
 import gradio as gr
+from transformers import AutoTokenizer, AutoModel
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -28,27 +29,34 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
 # ============================================================
 dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
 
-# Example: convert short_prompt and long_prompt to embeddings
-from transformers import AutoTokenizer, AutoModel
+# ============================================================
+# 3️⃣ Prepare tokenizer & embedding model
+# ============================================================
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
+tokenizer.pad_token = tokenizer.eos_token  # FIX padding error
+
 embed_model = AutoModel.from_pretrained("gpt2").to(device)
+embed_model.eval()  # inference only
 
 def encode_prompt(prompt):
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding="max_length", max_length=32).to(device)
     with torch.no_grad():
         return embed_model(**inputs).last_hidden_state.mean(dim=1)
 
-short_embeddings = torch.vstack([encode_prompt(p["short_prompt"]) for p in dataset])
-long_embeddings = torch.vstack([encode_prompt(p["long_prompt"]) for p in dataset])
+# Encode all dataset prompts
+print("📦 Encoding dataset prompts...")
+short_embeddings = torch.vstack([encode_prompt(p["short_prompt"]) for p in dataset]).to(device)
+long_embeddings = torch.vstack([encode_prompt(p["long_prompt"]) for p in dataset]).to(device)
+print(f"✅ Encoded {len(dataset)} prompts")
 
 # ============================================================
-# 3️⃣ Train FlashPack model
+# 4️⃣ Train FlashPack model
 # ============================================================
 model = GemmaTrainer(input_dim=short_embeddings.shape[1], output_dim=long_embeddings.shape[1]).to(device)
 criterion = nn.MSELoss()
 optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
-max_epochs = 1000
+max_epochs = 500
 tolerance = 1e-4
 
 for epoch in range(max_epochs):
@@ -58,57 +66,52 @@ for epoch in range(max_epochs):
     loss.backward()
     optimizer.step()
     if loss.item() < tolerance:
-        print(f"Training converged at epoch {epoch+1}")
+        print(f"Converged at epoch {epoch+1}, Loss={loss.item():.6f}")
         break
-    if epoch % 50 == 0:
-        print(f"Epoch {epoch+1}, Loss: {loss.item():.6f}")
+    if (epoch + 1) % 50 == 0:
+        print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")
 
 # ============================================================
-# 4️⃣ Save to FlashPack Hub
+# 5️⃣ Save FlashPack model to Hub
 # ============================================================
 FLASHPACK_REPO = "rahul7star/FlashPack"
 model.save_flashpack(FLASHPACK_REPO, target_dtype=torch.float32, push_to_hub=True)
-print("✅ Model saved to FlashPack Hub!")
+print(f"✅ Model saved to FlashPack Hub: {FLASHPACK_REPO}")
 
 # ============================================================
-# 5️⃣ Load FlashPack model
+# 6️⃣ Load FlashPack model
 # ============================================================
 loaded_model = model.from_flashpack(FLASHPACK_REPO)
 
 # ============================================================
-# 6️⃣ Gradio interface
+# 7️⃣ Gradio interface
 # ============================================================
 def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
     chat_history = chat_history or []
 
-    # Encode short prompt
+    # Encode user prompt
     short_emb = encode_prompt(user_prompt)
-
-    # Generate expanded embedding via trained model
     with torch.no_grad():
         long_emb = loaded_model(short_emb)
 
-    # Decode embedding back to text (approximate via nearest training example)
-    # Simple approach: cosine similarity to long_embeddings
+    # Find nearest matching long prompt in dataset (simple approach)
     cos = nn.CosineSimilarity(dim=1)
     sims = cos(long_emb.repeat(len(long_embeddings),1), long_embeddings)
     best_idx = sims.argmax()
    enhanced_prompt = dataset[best_idx]["long_prompt"]
 
-    # Update chat history
     chat_history.append({"role": "user", "content": user_prompt})
     chat_history.append({"role": "assistant", "content": enhanced_prompt})
    return chat_history
 
 # ============================================================
-# 7️⃣ Gradio UI
+# 8️⃣ Gradio UI
 # ============================================================
 with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
         # ✨ Prompt Enhancer (Gemma 3 270M)
-        Enter a short prompt, and the model will **expand it with details and creative context**
-        using the Gemma chat-template interface.
+        Enter a short prompt, and the model will **expand it with details and creative context**
         """
     )
 
@@ -138,8 +141,5 @@ with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft())
         """
     )
 
-# ============================================================
-# 8️⃣ Launch
-# ============================================================
 if __name__ == "__main__":
     demo.launch(show_error=True)
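
The two additions in the tokenizer section fix a real failure mode: GPT-2's tokenizer defines no padding token, so the original padding="max_length" call raises a padding error until one is assigned, and tokenizer.pad_token = tokenizer.eos_token is the standard workaround. Below is a self-contained sketch of the encoding step under that fix; the attention-mask weighting is an assumption added here for illustration, whereas the commit keeps a plain mean over all 32 positions, pads included.

import torch
from transformers import AutoTokenizer, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships with no pad token
embed_model = AutoModel.from_pretrained("gpt2").to(device)
embed_model.eval()

@torch.no_grad()
def encode_prompt(prompt: str) -> torch.Tensor:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       padding="max_length", max_length=32).to(device)
    hidden = embed_model(**inputs).last_hidden_state    # (1, 32, 768)
    mask = inputs["attention_mask"].unsqueeze(-1)       # (1, 32, 1)
    # Mean over real tokens only, so pad positions don't dilute the embedding
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1) # (1, 768)

print(encode_prompt("a cat in a hat").shape)  # torch.Size([1, 768])

Masked pooling matters precisely because of the pad fix: with eos_token reused as padding, an unmasked mean averages in up to 31 identical pad embeddings for a short prompt.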
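
The third hunk opens mid-loop at loss.backward(), so the forward pass (old lines 55-57, new lines 63-65) never appears in this diff. What follows is a hypothetical reconstruction of the full loop, assuming the conventional zero_grad / forward / loss ordering and reusing the names defined in app_flash.py; the hidden lines may read differently in the actual file.

# Hypothetical reconstruction: everything before loss.backward() is inferred,
# not shown in the commit. Uses model/criterion/optimizer from app_flash.py.
for epoch in range(max_epochs):
    optimizer.zero_grad()
    pred = model(short_embeddings)            # map short-prompt embeddings
    loss = criterion(pred, long_embeddings)   # toward long-prompt embeddings
    loss.backward()
    optimizer.step()
    if loss.item() < tolerance:
        print(f"Converged at epoch {epoch + 1}, Loss={loss.item():.6f}")
        break
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch + 1}, Loss={loss.item():.6f}")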
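
The "decode" step in enhance_prompt is nearest-neighbour retrieval: the predicted embedding is compared against every training long_prompt embedding by cosine similarity and the closest row's text is returned. Two small notes on the diff's version: the repeat() call is redundant, since cosine similarity broadcasts a (1, D) query against (N, D) rows, and best_idx is a tensor, which is safer unwrapped with .item() before indexing the datasets.Dataset. A minimal sketch with dummy tensors:

import torch
import torch.nn.functional as F

long_embeddings = torch.randn(1000, 768)  # stand-ins for the dataset embeddings
long_emb = torch.randn(1, 768)            # stand-in for the model's prediction

# Broadcasting replaces long_emb.repeat(len(long_embeddings), 1)
sims = F.cosine_similarity(long_emb, long_embeddings, dim=1)  # shape (1000,)
best_idx = sims.argmax().item()  # plain int, safe for dataset[best_idx]
print(best_idx, round(sims[best_idx].item(), 4))

Note the ceiling this design imposes: the app only ever returns a long prompt that already exists in the training set; the trained model picks among stored examples rather than generating new text.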