rahul7star committed (verified)
Commit ba4b2f5 · 1 Parent(s): 2781180

Update app_flash.py

Files changed (1): app_flash.py (+74 -184)
app_flash.py CHANGED
@@ -1,227 +1,117 @@
  import os
- import re
  import torch
- import torch.nn as nn
- import torch.optim as optim
- from typing import Tuple
- from datasets import load_dataset
- from flashpack import FlashPackMixin
- from huggingface_hub import HfApi, create_repo, repo_exists
  import gradio as gr
- from transformers import AutoTokenizer, AutoModel
-
- # ============================================================
- # ⚙️ Setup
- # ============================================================
- device = torch.device("cpu")
- torch.set_num_threads(4)
- print(f"🔧 Using device: {device} (CPU-only mode)")
-
- HF_REPO = "rahul7star/FlashPack"
- MODEL_ID = HF_REPO
-
-
- # ============================================================
- # 🧠 Define FlashPack Trainer
- # ============================================================
- class GemmaTrainer(nn.Module, FlashPackMixin):
-     def __init__(self, input_dim=768, hidden_dim=512, output_dim=768):
-         super().__init__()
-         self.fc1 = nn.Linear(input_dim, hidden_dim)
-         self.relu = nn.ReLU()
-         self.fc2 = nn.Linear(hidden_dim, output_dim)
-
-     def forward(self, x):
-         return self.fc2(self.relu(self.fc1(x)))
+ from diffusers import DiffusionPipeline
+ from flashpack.integrations.diffusers import (
+     FlashPackDiffusersModelMixin,
+     FlashPackDiffusionPipeline,
+ )
+ from huggingface_hub import snapshot_download


  # ============================================================
- # 🔀 Encoder Builder (GPT2 base)
+ # 🧠 Device setup (CPU fallback safe)
  # ============================================================
- def build_encoder(model_name="gpt2", max_length=32):
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     if tokenizer.pad_token is None:
-         tokenizer.pad_token = tokenizer.eos_token
-     embed_model = AutoModel.from_pretrained(model_name).to(device)
-     embed_model.eval()
-
-     @torch.no_grad()
-     def encode(text: str):
-         inputs = tokenizer(
-             text,
-             return_tensors="pt",
-             truncation=True,
-             padding="max_length",
-             max_length=max_length,
-         ).to(device)
-         return embed_model(**inputs).last_hidden_state.mean(dim=1).cpu()
-
-     return tokenizer, embed_model, encode
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"🔧 Using device: {device}")


  # ============================================================
- # 🧩 FlashPack: Train and Upload (uses Gemma only internally)
+ # 🧩 Define FlashPack-integrated pipeline
  # ============================================================
- def train_flashpack_model(hf_repo=HF_REPO):
-     print(f"🚀 Training new FlashPack model for repo: {hf_repo}")
-     model = GemmaTrainer()
-     tokenizer, embed_model, encode = build_encoder("gpt2")
-
-     # Load dataset (Gemma-expanded dataset)
-     dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
-
-     # Compute embeddings for training (short → long)
-     X, Y = [], []
-     for p in dataset.select(range(300)):
-         short_emb = encode(p["short_prompt"])
-         long_emb = encode(p["long_prompt"])
-         X.append(short_emb)
-         Y.append(long_emb)
-
-     X = torch.vstack(X)
-     Y = torch.vstack(Y)
-
-     optimizer = optim.Adam(model.parameters(), lr=1e-3)
-     for epoch in range(10):
-         out = model(X)
-         loss = nn.MSELoss()(out, Y)
-         optimizer.zero_grad()
-         loss.backward()
-         optimizer.step()
-         print(f"Epoch {epoch+1}/10 | Loss: {loss.item():.6f}")
-
-     # Save FlashPack model and push
-     model.to_flashpack("flashpack_model")
-     print("💾 Model saved locally. Uploading to Hugging Face...")
-
-     api = HfApi()
-     if not repo_exists(hf_repo):
-         create_repo(hf_repo, repo_type="model", exist_ok=True)
-     model.push_to_hub(hf_repo, commit_message="Initial FlashPack model training")
-
-     print(f"✅ Model uploaded successfully to {hf_repo}")
-     return model, tokenizer, embed_model
+ class FlashPackMyPipeline(DiffusionPipeline, FlashPackDiffusionPipeline):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)


  # ============================================================
- # 📦 Load FlashPack from Hub
+ # 🚀 Load FlashPack pipeline
  # ============================================================
- from huggingface_hub import snapshot_download
- import os
-
- def load_flashpack_model(hf_repo="rahul7star/FlashPack"):
-     print(f"🔍 Loading FlashPack model from: {hf_repo}")
-
-     # Try local first, then Hugging Face Hub
-     if os.path.isdir(hf_repo):
-         local_dir = hf_repo
-         print(f"📂 Using local FlashPack model at: {local_dir}")
-     else:
-         print("☁️ Downloading FlashPack model from Hugging Face Hub...")
-         local_dir = snapshot_download(repo_id=hf_repo)
-         print(f"📥 Model snapshot downloaded to: {local_dir}")
-
-     # Load from local directory
-     model = GemmaTrainer.from_flashpack(local_dir)
-     model.eval()
-     print("✅ FlashPack model loaded successfully.")
-     return model
+ def load_flashpack_pipeline(repo_id: str = "rahul7star/FlashPack"):
+     """
+     Loads a FlashPack pipeline from Hugging Face Hub.
+     Falls back to local snapshot if network or metadata issue occurs.
+     """
+     print(f"🔍 Loading FlashPack pipeline from: {repo_id}")

+     try:
+         # Try direct hub load
+         pipeline = FlashPackMyPipeline.from_pretrained_flashpack(repo_id)
+         print("✅ Successfully loaded FlashPack pipeline from Hugging Face Hub.")
+     except Exception as e:
+         print(f"⚠️ Hub load failed: {e}")
+         print("⏬ Attempting to load via snapshot_download...")
+         try:
+             local_dir = snapshot_download(repo_id=repo_id)
+             pipeline = FlashPackMyPipeline.from_pretrained_flashpack(local_dir)
+             print(f"✅ Loaded FlashPack pipeline from local snapshot: {local_dir}")
+         except Exception as e2:
+             raise RuntimeError(f"❌ Failed to load FlashPack model: {e2}")

+     pipeline.to(device)
+     return pipeline


  # ============================================================
- # ⚡ Auto Load or Train
+ # 🧪 Inference function
  # ============================================================
- def get_flashpack_model(hf_repo=HF_REPO):
+ def generate_from_prompt(prompt: str):
+     if not prompt or prompt.strip() == "":
+         return "Please enter a valid prompt.", None
+
      try:
-         api = HfApi()
-         if repo_exists(hf_repo):
-             print("✅ Found trained model on Hub.")
-             return load_flashpack_model(hf_repo)
+         output = pipeline(prompt)
+         if hasattr(output, "images"):
+             img = output.images[0]
+             return f"✅ Generated successfully!", img
+         elif hasattr(output, "frames"):
+             frames = output.frames
+             video_path = "/tmp/generated.mp4"
+             from diffusers.utils import export_to_video
+             export_to_video(frames, video_path)
+             return f"✅ Video generated successfully!", video_path
          else:
-             print("❌ Model not found, training new one using Gemma dataset...")
-             return train_flashpack_model(hf_repo)
+             return "⚠️ Unknown output format.", None
      except Exception as e:
-         print(f"⚠️ Repo check failed: {e}. Retraining model locally.")
-         return train_flashpack_model(hf_repo)
+         return f"❌ Inference error: {str(e)}", None


  # ============================================================
- # 📚 Dataset + Model
+ # ⚙️ Load the model
  # ============================================================
- model, tokenizer, embed_model = get_flashpack_model()
- dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
-
- long_embeddings = torch.vstack([
-     embed_model(**tokenizer(
-         p["long_prompt"], return_tensors="pt",
-         truncation=True, padding="max_length", max_length=32
-     )).last_hidden_state.mean(dim=1).cpu()
-     for p in dataset.select(range(min(500, len(dataset))))
- ])
-
- print("✅ FlashPack model and embeddings loaded.")
+ try:
+     pipeline = load_flashpack_pipeline("rahul7star/FlashPack")
+ except Exception as e:
+     raise SystemExit(f"🚫 Failed to load model: {e}")


  # ============================================================
- # 🧠 Inference Helpers
+ # 🧠 Gradio UI
  # ============================================================
- @torch.no_grad()
- def encode_for_inference(prompt: str):
-     inputs = tokenizer(
-         prompt,
-         return_tensors="pt",
-         truncation=True,
-         padding="max_length",
-         max_length=32,
-     ).to(device)
-     return embed_model(**inputs).last_hidden_state.mean(dim=1).cpu()
-
-
- @torch.no_grad()
- def enhance_prompt_flashpack(user_prompt: str, temperature: float, max_tokens: int, chat_history):
-     chat_history = chat_history or []
-     short_emb = encode_for_inference(user_prompt)
-     mapped = model(short_emb.to(device)).cpu()
-
-     sims = (long_embeddings @ mapped.t()).squeeze(1)
-     sims /= (long_embeddings.norm(dim=1) * (mapped.norm() + 1e-12))
-     best_idx = int(sims.argmax().item())
-     enhanced_prompt = dataset[best_idx]["long_prompt"]
-
-     chat_history.append({"role": "user", "content": user_prompt})
-     chat_history.append({"role": "assistant", "content": enhanced_prompt})
-     return chat_history
-
-
- # ============================================================
- # 💬 Gradio UI
- # ============================================================
- with gr.Blocks(title="Prompt Enhancer – FlashPack Only", theme=gr.themes.Soft()) as demo:
+ with gr.Blocks(title="FlashPack Model – rahul7star/FlashPack", theme=gr.themes.Soft()) as demo:
      gr.Markdown("""
-     # ✨ FlashPack Prompt Enhancer
-     - Uses pre-trained **FlashPack model** (`rahul7star/FlashPack`)
-     - Matches short prompts to enhanced long prompts using learned embeddings
-     - CPU-only, no Gemma dependency during inference.
+     # ⚡ FlashPack Model Inference
+     - Loaded from **rahul7star/FlashPack**
+     - Supports both image and video outputs (depending on model type)
      """)

      with gr.Row():
-         chatbot = gr.Chatbot(height=420, label="Enhanced Prompts", type="messages")
          with gr.Column(scale=1):
-             user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
-             temperature = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
-             max_tokens = gr.Slider(32, 512, value=256, label="Max Tokens")
-             send_flashpack = gr.Button("🔗 Enhance Prompt")
-             clear_btn = gr.Button("🧹 Clear Chat")
-
-     send_flashpack.click(enhance_prompt_flashpack, [user_prompt, temperature, max_tokens, chatbot], chatbot)
-     user_prompt.submit(enhance_prompt_flashpack, [user_prompt, temperature, max_tokens, chatbot], chatbot)
-     clear_btn.click(lambda: [], None, chatbot)
+             prompt = gr.Textbox(label="Enter your prompt", placeholder="e.g. A robot painting in the rain")
+             run_btn = gr.Button("🚀 Generate", variant="primary")
+         with gr.Column(scale=1):
+             result_msg = gr.Textbox(label="Status", interactive=False)
+             image_out = gr.Image(label="Generated Image")
+             video_out = gr.Video(label="Generated Video")

+     run_btn.click(
+         generate_from_prompt,
+         inputs=[prompt],
+         outputs=[result_msg, image_out],
+     )

  # ============================================================
- # 🏁 Launch app
+ # 🏁 Launch app
  # ============================================================
  if __name__ == "__main__":
      demo.launch(show_error=True)
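
One wiring detail worth noting in the new UI: `video_out` is created, but `run_btn.click` routes only `outputs=[result_msg, image_out]`, so when `generate_from_prompt` returns a video file path it lands in the `gr.Image` component. Below is a minimal sketch (not part of this commit) of one way to route both outputs, reusing the names defined above (`pipeline`, `prompt`, `run_btn`, `result_msg`, `image_out`, `video_out`); `generate_both` is a hypothetical helper, and the `.click` registration must run inside the `gr.Blocks` context:

    def generate_both(prompt_text: str):
        # Return (status, image, video) so each output component gets its own slot.
        if not prompt_text or not prompt_text.strip():
            return "Please enter a valid prompt.", None, None
        try:
            output = pipeline(prompt_text)
            if hasattr(output, "images"):
                return "✅ Generated successfully!", output.images[0], None
            if hasattr(output, "frames"):
                from diffusers.utils import export_to_video
                video_path = "/tmp/generated.mp4"
                # Depending on the pipeline, output.frames may be a nested list
                # (use output.frames[0] in that case) before exporting.
                export_to_video(output.frames, video_path)
                return "✅ Video generated successfully!", None, video_path
            return "⚠️ Unknown output format.", None, None
        except Exception as e:
            return f"❌ Inference error: {e}", None, None

    run_btn.click(
        generate_both,
        inputs=[prompt],
        outputs=[result_msg, image_out, video_out],
    )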