aladdin1995 committed · verified
Commit bd177b1 · Parent(s): 327846c

Update app.py

Files changed (1):
  1. app.py (+8 -118)

app.py CHANGED
@@ -53,10 +53,10 @@ class PromptEnhancerV2:
         self,
         prompt_cot,
         sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
-        temperature=0.0,
+        temperature=0.1,
         top_p=1.0,
         max_new_tokens=2048,
-        device="cuda:0",
+        device="cuda",
     ):
         org_prompt_cot = prompt_cot
         try:
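A note on this first hunk: in transformers, `temperature` only takes effect when sampling is enabled, and the sampling path rejects a temperature of exactly 0.0, so 0.1 is the usual "near-greedy but still valid" default. A minimal sketch of a call with the new defaults; the checkpoint name is a placeholder assumption, not this Space's actual weights:

```python
# Minimal sketch: near-greedy sampling with the new defaults.
# "Qwen/Qwen2.5-0.5B-Instruct" is a stand-in checkpoint, not the
# enhancer weights this Space actually loads.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # assumption
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda", torch_dtype=torch.bfloat16
)

inputs = tok("Rewrite this prompt: a cat on a mat", return_tensors="pt").to(model.device)
out = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,   # temperature/top_p only apply on the sampling path
    temperature=0.1,  # near-greedy; exactly 0.0 is rejected when sampling
    top_p=1.0,
)
print(tok.decode(out[0], skip_special_tokens=True))
```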
@@ -112,91 +112,6 @@ class PromptEnhancerV2:
             print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
 
         return prompt_cot
-    # @torch.inference_mode()
-    @spaces.GPU
-    def predict_stream(
-        self,
-        prompt_cot,
-        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
-        temperature=0.1,
-        top_p=1.0,
-        max_new_tokens=2048,
-        device="cuda:0",
-    ):
-        org_prompt_cot = prompt_cot
-
-        # Assemble the inputs, same as predict
-        user_prompt_format = sys_prompt + "\n" + org_prompt_cot
-        messages = [{"role": "user", "content": [{"type": "text", "text": user_prompt_format}]}]
-        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        image_inputs, video_inputs = process_vision_info(messages)
-        inputs = self.processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
-        )
-        inputs = inputs.to(device)
-
-        # Get the tokenizer (processor.tokenizer usually exists; keep a fallback just in case)
-        tokenizer = getattr(self.processor, "tokenizer", None)
-        if tokenizer is None:
-            tokenizer = AutoTokenizer.from_pretrained(self.models_root_path, trust_remote_code=True)
-
-        streamer = TextIteratorStreamer(
-            tokenizer=tokenizer,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False,
-        )
-
-        gen_kwargs = dict(
-            **inputs,
-            max_new_tokens=max_new_tokens,
-            temperature=float(temperature),
-            do_sample=True,  # consistent with the original logic; set this to True for sampled streaming
-            top_k=5,
-            top_p=0.9,
-            streamer=streamer,
-        )
-
-        # Start generation in a worker thread; the main thread consumes the streamer
-        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
-        thread.start()
-
-        buffer = ""   # accumulates the full output (including the reasoning)
-        emitted = ""  # the portion of the "rewritten prompt" already emitted
-        already_stripped_newline = False
-
-        try:
-            for piece in streamer:
-                buffer += piece
-                part = buffer.split('assistant')[-1]
-                delta = part[len(emitted):]
-                if delta:
-                    emitted = part
-                    yield emitted  # push intermediate results to the frontend
-        finally:
-            thread.join()
-
-        # If the second think> never arrives, fall back to the original prompt
-        # if emitted.strip() == "":
-        #     yield replace_single_quotes(org_prompt_cot)
-        try:
-            assert emitted.count("think>") == 2
-            prompt_cot = emitted.split("think>")[-1]
-            if prompt_cot.startswith("\n"):
-                prompt_cot = prompt_cot[1:]
-            prompt_cot = emitted.split('assistant')[-1] + '\n \n Recaption:' + replace_single_quotes(prompt_cot)
-            # prompt_cot = replace_single_quotes(prompt_cot)
-            yield prompt_cot
-        except Exception as e:
-            prompt_cot = org_prompt_cot
-            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
-            yield prompt_cot
-
-
-
 # -------------------------
 # Gradio app helpers
 # -------------------------
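For reference, the deleted `predict_stream` followed the standard transformers streaming recipe: `generate()` blocks, so it runs in a worker thread while the main thread drains a `TextIteratorStreamer`. A self-contained sketch of that pattern with a placeholder checkpoint (the real method additionally routed inputs through the processor's chat template and vision pipeline):

```python
# Sketch of the streaming pattern the removed predict_stream used:
# generation runs in a background thread; the streamer is a thread-safe
# queue of decoded text chunks that the main thread iterates over.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # assumption: placeholder checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

inputs = tok("Hello", return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tok, skip_special_tokens=True)

thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=128, do_sample=True,
                temperature=0.1, streamer=streamer),
)
thread.start()
for piece in streamer:           # yields text chunks as they are generated
    print(piece, end="", flush=True)
thread.join()                    # always rejoin the worker, as the old code did
```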
@@ -223,32 +138,6 @@ def ensure_enhancer(state, model_path, device_map, torch_dtype):
         return {"enhancer": enhancer, "model_path": model_path, "device_map": device_map, "torch_dtype": torch_dtype}
     return state
 
-def stream_single(prompt, sys_prompt, temperature, max_new_tokens, device,
-                  model_path, device_map, torch_dtype, state):
-    if not prompt or not str(prompt).strip():
-        yield "", "请先输入提示词。", state
-        return
-
-    t0 = time.time()
-    state = ensure_enhancer(state, model_path, device_map, torch_dtype)
-    enhancer = state["enhancer"]
-
-    emitted = ""
-    try:
-        for chunk in enhancer.predict_stream(
-            prompt_cot=prompt,
-            sys_prompt=sys_prompt,
-            temperature=temperature,
-            max_new_tokens=max_new_tokens,
-            device=device
-        ):
-            emitted = chunk
-            info = f"已接收 {len(emitted)} 字符,用时 {time.time()-t0:.2f}s"
-            yield emitted, info, state
-        # Emit the final state once more at the end (optional)
-        yield emitted, f"完成。总耗时 {time.time()-t0:.2f}s", state
-    except Exception as e:
-        yield "", f"推理失败:{e}", state
 
 def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
                model_path, device_map, torch_dtype, state):
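The deleted `stream_single` leaned on Gradio's generator-handler support: an event callback that `yield`s partial results updates its output components incrementally. A minimal sketch of that mechanism, with illustrative component names rather than the app's:

```python
# Minimal sketch of Gradio's streaming-handler mechanism: a generator
# callback re-renders its outputs on every yield. Component names here
# are illustrative only.
import time
import gradio as gr

def slow_echo(prompt):
    out = ""
    for ch in prompt:
        time.sleep(0.05)  # stand-in for per-token generation latency
        out += ch
        yield out         # each yield replaces the output Textbox value

with gr.Blocks() as demo:
    inp = gr.Textbox(label="prompt")
    btn = gr.Button("run")
    txt = gr.Textbox(label="output")
    btn.click(slow_echo, inputs=inp, outputs=txt)

# demo.launch()  # uncomment to serve locally
```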
@@ -295,11 +184,12 @@ with gr.Blocks(title="Prompt Enhancer_V2") as demo:
         value=DEFAULT_MODEL_PATH,
         placeholder="/apdcephfs_jn3/share_302243908/aladdinwang/model_weight/cot_taurus_v6_50/global_step0",
     )
-    device_map = gr.Dropdown(
-        choices=["auto", "cuda", "cpu"],
-        value="auto",
-        label="device_map(模型加载映射)"
-    )
+    device_map = "cuda"
+    # gr.Dropdown(
+    #     choices=["auto", "cuda", "cpu"],
+    #     value="auto",
+    #     label="device_map(模型加载映射)"
+    # )
     torch_dtype = gr.Dropdown(
         choices=["bfloat16", "float16", "float32"],
         value="bfloat16",
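With the dropdown gone, the hard-coded string flows into model loading unchanged. Assuming `ensure_enhancer` forwards it to `from_pretrained` (not shown in this diff), the effect is single-device placement rather than accelerate's `"auto"` sharding. A sketch with a placeholder checkpoint:

```python
# Sketch: what device_map="cuda" implies at load time, assuming
# ensure_enhancer passes it through to from_pretrained. A device string
# places the whole model on that device; "auto" would instead let
# accelerate shard it across available devices.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # assumption: stand-in for the enhancer weights
    device_map="cuda",             # single-device placement, no sharding
    torch_dtype=torch.bfloat16,    # matches the UI's default dtype choice
)
print(next(model.parameters()).device)  # -> cuda:0
```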