rahul7star committed on
Commit
0877c51
·
verified ·
1 Parent(s): 4f932c2

Update app_flash.py

Files changed (1)
  1. app_flash.py +198 -77
app_flash.py CHANGED
@@ -1,114 +1,235 @@
 
import os
import torch
import gradio as gr
- from diffusers import DiffusionPipeline
- from flashpack.integrations.diffusers import (
-     FlashPackDiffusersModelMixin,
-     FlashPackDiffusionPipeline,
- )
- from huggingface_hub import snapshot_download
-
# ============================================================
- # 🧠 Device setup (CPU fallback safe)
# ============================================================
- device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"🔧 Using device: {device}")
-
# ============================================================
- # 🧩 Define FlashPack-integrated pipeline
# ============================================================
- class FlashPackMyPipeline(DiffusionPipeline, FlashPackDiffusionPipeline):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)

# ============================================================
- # 🚀 Load FlashPack pipeline
# ============================================================
- def load_flashpack_pipeline(repo_id: str = "rahul7star/FlashPack"):
-     """
-     Loads a FlashPack pipeline from the Hugging Face Hub.
-     Falls back to a local snapshot if a network or metadata issue occurs.
-     """
-     print(f"🔍 Loading FlashPack pipeline from: {repo_id}")
-
-     try:
-         # Try a direct Hub load
-         pipeline = FlashPackMyPipeline.from_pretrained_flashpack(repo_id)
-         print("✅ Successfully loaded FlashPack pipeline from Hugging Face Hub.")
-     except Exception as e:
-         print(f"⚠️ Hub load failed: {e}")
-         print("⏬ Attempting to load via snapshot_download...")
-         try:
-             local_dir = snapshot_download(repo_id=repo_id)
-             pipeline = FlashPackMyPipeline.from_pretrained_flashpack(local_dir)
-             print(f"✅ Loaded FlashPack pipeline from local snapshot: {local_dir}")
-         except Exception as e2:
-             raise RuntimeError(f"❌ Failed to load FlashPack model: {e2}")
-
-     pipeline.to(device)
-     return pipeline
# ============================================================
- # 🧪 Inference function
# ============================================================
- def generate_from_prompt(prompt: str):
-     if not prompt or prompt.strip() == "":
-         return "Please enter a valid prompt.", None

    try:
-         output = pipeline(prompt)
-         if hasattr(output, "images"):
-             img = output.images[0]
-             return "✅ Generated successfully!", img
-         elif hasattr(output, "frames"):
-             frames = output.frames
-             video_path = "/tmp/generated.mp4"
-             from diffusers.utils import export_to_video
-             export_to_video(frames, video_path)
-             return "✅ Video generated successfully!", video_path
-         else:
-             return "⚠️ Unknown output format.", None
    except Exception as e:
-         return f"❌ Inference error: {str(e)}", None
-

# ============================================================
- # ⚙️ Load the model
# ============================================================
try:
-     pipeline = load_flashpack_pipeline("rahul7star/FlashPack")
except Exception as e:
-     raise SystemExit(f"🚫 Failed to load model: {e}")
# ============================================================
- # 🧠 Gradio UI
# ============================================================
- with gr.Blocks(title="FlashPack Model – rahul7star/FlashPack", theme=gr.themes.Soft()) as demo:
-     gr.Markdown("""
-     # ⚡ FlashPack Model Inference
-     - Loaded from **rahul7star/FlashPack**
-     - Supports both image and video outputs (depending on model type)
-     """)

    with gr.Row():
        with gr.Column(scale=1):
-             prompt = gr.Textbox(label="Enter your prompt", placeholder="e.g. A robot painting in the rain")
-             run_btn = gr.Button("🚀 Generate", variant="primary")
-         with gr.Column(scale=1):
-             result_msg = gr.Textbox(label="Status", interactive=False)
-             image_out = gr.Image(label="Generated Image")
-             video_out = gr.Video(label="Generated Video")
-
-     run_btn.click(
-         generate_from_prompt,
-         inputs=[prompt],
-         outputs=[result_msg, image_out],
-     )

# ============================================================
# 🏁 Launch app
+ import gc
import os
import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import tempfile
import gradio as gr
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModel
+ from flashpack import FlashPackMixin
+ from huggingface_hub import Repository
+ from typing import Tuple

# ============================================================
+ # 🖥 Device setup (CPU-only safe)
# ============================================================
+ device = torch.device("cpu")
+ torch.set_num_threads(4)
+ print(f"🔧 Using device: {device} (CPU-only mode)")

# ============================================================
+ # 1️⃣ Define FlashPack model
# ============================================================
+ class GemmaTrainer(nn.Module, FlashPackMixin):
+     def __init__(self, input_dim: int = 768, hidden_dim: int = 512, output_dim: int = 768):
+         super().__init__()
+         self.fc1 = nn.Linear(input_dim, hidden_dim)
+         self.relu = nn.ReLU()
+         self.fc2 = nn.Linear(hidden_dim, output_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc1(x)
+         x = self.relu(x)
+         x = self.fc2(x)
+         return x
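As a sanity check, a minimal sketch of how the mapper above behaves (hypothetical usage; the 768-dim defaults match GPT-2's hidden size):

    mapper = GemmaTrainer()            # 768 -> 512 -> 768 by default
    out = mapper(torch.randn(4, 768))  # a batch of 4 pooled prompt embeddings
    assert out.shape == (4, 768)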
 
+ # ============================================================
+ # 2️⃣ Build tokenizer + encoder
+ # ============================================================
+ def build_encoder(model_name="gpt2", max_length: int = 32):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     embed_model = AutoModel.from_pretrained(model_name).to(device)
+     embed_model.eval()
+
+     @torch.no_grad()
+     def encode(prompt: str) -> torch.Tensor:
+         inputs = tokenizer(
+             prompt,
+             return_tensors="pt",
+             truncation=True,
+             padding="max_length",
+             max_length=max_length
+         ).to(device)
+         outputs = embed_model(**inputs).last_hidden_state.mean(dim=1)
+         return outputs.cpu()
+
+     return tokenizer, embed_model, encode
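A usage sketch for build_encoder, assuming the GPT-2 defaults above; each prompt is tokenized to 32 positions and mean-pooled into a single 768-dim vector:

    tokenizer, embed_model, encode = build_encoder("gpt2", max_length=32)
    vec = encode("a cat in the rain")  # hypothetical prompt
    print(vec.shape)                   # torch.Size([1, 768])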
 
# ============================================================
+ # 3️⃣ Push FlashPack model to HF
# ============================================================
+ def push_flashpack_model_to_hf(model, hf_repo: str):
+     logs = []
+     with tempfile.TemporaryDirectory() as tmp_dir:
+         logs.append(f"📂 Using temporary directory: {tmp_dir}")
+         repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
+         logs.append(f"🌐 Hugging Face repo cloned to: {tmp_dir}")
+
+         pack_path = os.path.join(tmp_dir, "model.flashpack")
+         logs.append(f"💾 Saving model to: {pack_path}")
+         model.save_flashpack(pack_path, target_dtype=torch.float32)
+         logs.append("✅ Model saved successfully.")
+
+         readme_path = os.path.join(tmp_dir, "README.md")
+         with open(readme_path, "w") as f:
+             f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
+         logs.append("📄 README.md added.")
+
+         logs.append("🚀 Pushing repo to Hugging Face Hub...")
+         repo.push_to_hub()
+         logs.append(f"✅ Model successfully pushed to: {hf_repo}")
+
+     return logs
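Note that Repository is deprecated in recent huggingface_hub releases; a minimal alternative sketch using HfApi.upload_folder, assuming the same tmp_dir layout and an ambient HF token:

    from huggingface_hub import HfApi

    api = HfApi()  # picks up the cached token, as Repository does
    api.upload_folder(folder_path=tmp_dir, repo_id=hf_repo, repo_type="model")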
 
# ============================================================
+ # 4️⃣ Train FlashPack model
# ============================================================
+ def train_flashpack_model(
+     dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
+     max_encode: int = 1000,
+     device: str = "cpu"
+ ) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
+     print("📦 Loading dataset...")
+     dataset = load_dataset(dataset_name, split="train")
+     limit = min(max_encode, len(dataset))
+     dataset = dataset.select(range(limit))
+     print(f"⚡ Encoding {len(dataset)} prompts (max {max_encode})")
+
+     tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
+
+     short_list, long_list = [], []
+     for i, item in enumerate(dataset):
+         short_list.append(encode_fn(item["short_prompt"]))
+         long_list.append(encode_fn(item["long_prompt"]))
+         if (i + 1) % 50 == 0 or (i + 1) == len(dataset):
+             print(f"  → Encoded {i + 1}/{limit} prompts")
+             gc.collect()
+
+     short_embeddings = torch.vstack(short_list)
+     long_embeddings = torch.vstack(long_list)
+     print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
+
+     # Build model
+     model = GemmaTrainer(
+         input_dim=short_embeddings.shape[1],
+         hidden_dim=min(512, short_embeddings.shape[1]),
+         output_dim=long_embeddings.shape[1]
+     ).to(device)
+
+     criterion = nn.MSELoss()
+     optimizer = optim.Adam(model.parameters(), lr=1e-3)
+     max_epochs = 20
+     batch_size = 32
+
+     print("🚀 Training model...")
+     n = short_embeddings.shape[0]
+     for epoch in range(max_epochs):
+         model.train()
+         epoch_loss = 0.0
+         perm = torch.randperm(n)
+         for start in range(0, n, batch_size):
+             idx = perm[start:start + batch_size]
+             inputs = short_embeddings[idx].to(device)
+             targets = long_embeddings[idx].to(device)
+
+             optimizer.zero_grad()
+             outputs = model(inputs)
+             loss = criterion(outputs, targets)
+             loss.backward()
+             optimizer.step()
+             epoch_loss += loss.item() * inputs.size(0)
+
+         epoch_loss /= n
+         if epoch % 5 == 0 or epoch == max_epochs - 1:
+             print(f"Epoch {epoch + 1}/{max_epochs}, Loss={epoch_loss:.6f}")
+
+     print("✅ Training finished!")
+     return model, dataset, embed_model, tokenizer, long_embeddings
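A quick smoke-test sketch for the trainer above; max_encode is kept small here (hypothetical value) since prompt encoding dominates CPU runtime:

    model, dataset, embed_model, tokenizer, long_emb = train_flashpack_model(max_encode=100)
    print(long_emb.shape)  # torch.Size([100, 768]) with the GPT-2 encoder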
 
+ # ============================================================
+ # 5️⃣ Load FlashPack model (train if missing)
+ # ============================================================
+ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    try:
+         print(f"🔍 Attempting to load FlashPack model from {hf_repo}")
+         model = GemmaTrainer.from_flashpack(hf_repo)
+         model.eval()
+         print("✅ Loaded model successfully from HF")
+         tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
+         # The caller unpacks five values, so the success path must also
+         # rebuild the retrieval corpus (same settings as train_flashpack_model).
+         dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
+         dataset = dataset.select(range(min(1000, len(dataset))))
+         long_embeddings = torch.vstack([encode_fn(item["long_prompt"]) for item in dataset])
+         return model, tokenizer, embed_model, dataset, long_embeddings
    except Exception as e:
+         print(f"⚠️ Load failed: {e}")
+         print("⏬ Training a new FlashPack model locally...")
+         model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
+         print("📤 Pushing trained model to HF...")
+         push_flashpack_model_to_hf(model, hf_repo)
+         return model, tokenizer, embed_model, dataset, long_embeddings
 
# ============================================================
+ # 6️⃣ Load or train
# ============================================================
try:
+     model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
except Exception as e:
+     raise SystemExit(f"❌ Failed to load or train FlashPack model: {e}")
+
+ # ============================================================
+ # 7️⃣ Inference helpers
+ # ============================================================
+ @torch.no_grad()
+ def encode_for_inference(prompt: str) -> torch.Tensor:
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
+                        padding="max_length", max_length=32).to(device)
+     return embed_model(**inputs).last_hidden_state.mean(dim=1).cpu()
+
+ def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
+     # temperature and max_tokens are accepted for UI parity but unused
+     # by this nearest-neighbour retrieval step.
+     chat_history = chat_history or []
+     short_emb = encode_for_inference(user_prompt)
+     mapped = model(short_emb.to(device)).cpu()
+
+     # Cosine similarity between the mapped query and every long-prompt embedding
+     sims = (long_embeddings @ mapped.t()).squeeze(1)
+     long_norms = long_embeddings.norm(dim=1)
+     mapped_norm = mapped.norm()
+     sims = sims / (long_norms * (mapped_norm + 1e-12))
+
+     best_idx = int(sims.argmax().item())
+     enhanced_prompt = dataset[best_idx]["long_prompt"]
+
+     chat_history.append({"role": "user", "content": user_prompt})
+     chat_history.append({"role": "assistant", "content": enhanced_prompt})
+     return chat_history
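The similarity block in enhance_prompt is plain cosine similarity; a toy sketch of the same computation on stand-in tensors:

    corpus = torch.randn(5, 768)  # stands in for long_embeddings
    query = torch.randn(1, 768)   # stands in for the mapped embedding
    sims = (corpus @ query.t()).squeeze(1)
    sims = sims / (corpus.norm(dim=1) * (query.norm() + 1e-12))
    best = int(sims.argmax())     # index of the closest long prompt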
 
# ============================================================
+ # 8️⃣ Gradio UI
# ============================================================
+ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         # ✨ Prompt Enhancer (FlashPack mapper)
+         Enter a short prompt, and the model will **expand it with details and creative context**.
+         (CPU-only mode.)
+         """
+     )

    with gr.Row():
+         chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
+             user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
+             temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
+             max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
+             send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
+             clear_btn = gr.Button("🧹 Clear Chat")
+
+     send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
+     user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
+     clear_btn.click(lambda: [], None, chatbot)
+
+ # ============================================================
+ # 9️⃣ Launch

# ============================================================
# 🏁 Launch app