Update app.py
app.py CHANGED
@@ -53,10 +53,10 @@ class PromptEnhancerV2:
         self,
         prompt_cot,
         sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
-        temperature=0.,
+        temperature=0.1,
         top_p=1.0,
         max_new_tokens=2048,
-        device="cuda:0",
+        device="cuda",
     ):
         org_prompt_cot = prompt_cot
         try:
@@ -112,91 +112,6 @@ class PromptEnhancerV2:
             print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
 
         return prompt_cot
-
-    # @torch.inference_mode()
-    @spaces.GPU
-    def predict_stream(
-        self,
-        prompt_cot,
-        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
-        temperature=0.1,
-        top_p=1.0,
-        max_new_tokens=2048,
-        device="cuda:0",
-    ):
-        org_prompt_cot = prompt_cot
-
-        # Assemble the inputs, same as predict
-        user_prompt_format = sys_prompt + "\n" + org_prompt_cot
-        messages = [{"role": "user", "content": [{"type": "text", "text": user_prompt_format}]}]
-        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        image_inputs, video_inputs = process_vision_info(messages)
-        inputs = self.processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
-        )
-        inputs = inputs.to(device)
-
-        # Get the tokenizer (processor.tokenizer usually has one; keep a fallback just in case)
-        tokenizer = getattr(self.processor, "tokenizer", None)
-        if tokenizer is None:
-            tokenizer = AutoTokenizer.from_pretrained(self.models_root_path, trust_remote_code=True)
-
-        streamer = TextIteratorStreamer(
-            tokenizer=tokenizer,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False,
-        )
-
-        gen_kwargs = dict(
-            **inputs,
-            max_new_tokens=max_new_tokens,
-            temperature=float(temperature),
-            do_sample=True,  # matches the original logic; set this to True for sampled streaming
-            top_k=5,
-            top_p=0.9,
-            streamer=streamer,
-        )
-
-        # Start generation in a worker thread; the main thread consumes the streamer
-        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
-        thread.start()
-
-        buffer = ""   # accumulates the full output (including the thinking part)
-        emitted = ""  # the "rewritten prompt" part already sent to the caller
-        already_stripped_newline = False
-
-        try:
-            for piece in streamer:
-                buffer += piece
-                part = buffer.split('assistant')[-1]
-                delta = part[len(emitted):]
-                if delta:
-                    emitted = part
-                    yield emitted  # push the intermediate result to the frontend
-        finally:
-            thread.join()
-
-        # If the second think> never arrives, fall back to the original prompt
-        # if emitted.strip() == "":
-        #     yield replace_single_quotes(org_prompt_cot)
-        try:
-            assert emitted.count("think>") == 2
-            prompt_cot = emitted.split("think>")[-1]
-            if prompt_cot.startswith("\n"):
-                prompt_cot = prompt_cot[1:]
-            prompt_cot = emitted.split('assistant')[-1] + '\n \n Recaption:' + replace_single_quotes(prompt_cot)
-            # prompt_cot = replace_single_quotes(prompt_cot)
-            yield prompt_cot
-        except Exception as e:
-            prompt_cot = org_prompt_cot
-            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
-            yield prompt_cot
-
-
-
 # -------------------------
 # Gradio app helpers
 # -------------------------
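For reference, the removed predict_stream follows the standard transformers streaming recipe: model.generate runs in a worker thread while the caller iterates a TextIteratorStreamer on the main thread. A minimal self-contained sketch of that pattern (the gpt2 checkpoint and prompt below are stand-ins, not this Space's model):

# Sketch of the thread + TextIteratorStreamer pattern used by the removed
# method; "gpt2" is a placeholder checkpoint for illustration only.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Rewrite this prompt:", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

# generate() blocks until decoding finishes, so it runs in a worker thread;
# the streamer then yields decoded text pieces here as tokens arrive.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=64, do_sample=True, streamer=streamer),
)
thread.start()
for piece in streamer:
    print(piece, end="", flush=True)
thread.join()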
@@ -223,32 +138,6 @@ def ensure_enhancer(state, model_path, device_map, torch_dtype):
         return {"enhancer": enhancer, "model_path": model_path, "device_map": device_map, "torch_dtype": torch_dtype}
     return state
 
-def stream_single(prompt, sys_prompt, temperature, max_new_tokens, device,
-                  model_path, device_map, torch_dtype, state):
-    if not prompt or not str(prompt).strip():
-        yield "", "请先输入提示词。", state
-        return
-
-    t0 = time.time()
-    state = ensure_enhancer(state, model_path, device_map, torch_dtype)
-    enhancer = state["enhancer"]
-
-    emitted = ""
-    try:
-        for chunk in enhancer.predict_stream(
-            prompt_cot=prompt,
-            sys_prompt=sys_prompt,
-            temperature=temperature,
-            max_new_tokens=max_new_tokens,
-            device=device
-        ):
-            emitted = chunk
-            info = f"已接收 {len(emitted)} 字符,用时 {time.time()-t0:.2f}s"
-            yield emitted, info, state
-        # Emit the final status once more when finished (optional)
-        yield emitted, f"完成。总耗时 {time.time()-t0:.2f}s", state
-    except Exception as e:
-        yield "", f"推理失败:{e}", state
 
 def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
                model_path, device_map, torch_dtype, state):
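The removed stream_single wrapper depends on Gradio's generator-callback support: an event handler that yields repeatedly re-renders its bound output components on every yield. A minimal sketch of that mechanism, with illustrative component names:

# Minimal sketch: a generator handler streams partial text into a Textbox.
import time

import gradio as gr

def stream_demo(prompt):
    out = ""
    for word in prompt.split():
        out += word + " "
        time.sleep(0.1)  # stand-in for per-token generation latency
        yield out        # each yield updates the bound output component

with gr.Blocks() as demo:
    box = gr.Textbox(label="prompt")
    result = gr.Textbox(label="streamed output")
    gr.Button("Run").click(stream_demo, inputs=box, outputs=result)

demo.launch()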
@@ -295,11 +184,12 @@ with gr.Blocks(title="Prompt Enhancer_V2") as demo:
                 value=DEFAULT_MODEL_PATH,
                 placeholder="/apdcephfs_jn3/share_302243908/aladdinwang/model_weight/cot_taurus_v6_50/global_step0",
             )
-            device_map = gr.Dropdown(
-                choices=["auto", "cuda", "cpu"],
-                value="auto",
-                label="device_map(模型加载映射)"
-            )
+            device_map = "cuda"
+            # gr.Dropdown(
+            #     choices=["auto", "cuda", "cpu"],
+            #     value="auto",
+            #     label="device_map(模型加载映射)"
+            # )
             torch_dtype = gr.Dropdown(
                 choices=["bfloat16", "float16", "float32"],
                 value="bfloat16",
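After this change, device_map is the constant string "cuda" rather than a UI choice, and presumably it is forwarded together with the torch_dtype dropdown value to the model loader. A hedged sketch of such a loading call (load_model and AutoModelForCausalLM are assumptions for illustration, not code from this Space):

# Sketch of how device_map / torch_dtype strings typically reach the loader.
import torch
from transformers import AutoModelForCausalLM

# Map the UI's dtype strings onto torch dtypes.
DTYPES = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}

def load_model(model_path, device_map="cuda", torch_dtype="bfloat16"):
    # device_map="cuda" places all weights on the GPU in one shot;
    # "auto" would instead let accelerate spread them across devices.
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device_map,
        torch_dtype=DTYPES[torch_dtype],
        trust_remote_code=True,
    )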