aladdin1995 commited on
Commit
327846c
·
verified ·
1 Parent(s): e219ec6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -3
app.py CHANGED
@@ -250,7 +250,27 @@ def stream_single(prompt, sys_prompt, temperature, max_new_tokens, device,
250
  except Exception as e:
251
  yield "", f"推理失败:{e}", state
252
 
 
 
 
 
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  # 示例数据
255
  test_list_zh = [
256
  "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
@@ -313,12 +333,18 @@ with gr.Blocks(title="Prompt Enhancer_V2") as demo:
313
  out_text = gr.Textbox(label="重写结果", lines=10)
314
  out_info = gr.Markdown("准备就绪。")
315
 
 
 
 
 
 
 
316
  run_btn.click(
317
- stream_single,
318
  inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
319
- model_path, device_map, torch_dtype, state],
320
  outputs=[out_text, out_info, state]
321
- )
322
 
323
  gr.Markdown(
324
  "提示:如有任何问题可email联系:linqing1995@buaa.edu.cn"
 
250
  except Exception as e:
251
  yield "", f"推理失败:{e}", state
252
 
253
+ def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
254
+ model_path, device_map, torch_dtype, state):
255
+ if not prompt or not str(prompt).strip():
256
+ return "", "请先输入提示词。", state
257
 
258
+ t0 = time.time()
259
+ state = ensure_enhancer(state, model_path, device_map, torch_dtype)
260
+ enhancer = state["enhancer"]
261
+ try:
262
+ out = enhancer.predict(
263
+ prompt_cot=prompt,
264
+ sys_prompt=sys_prompt,
265
+ temperature=temperature,
266
+ max_new_tokens=max_new_tokens,
267
+ device=device
268
+ )
269
+ dt = time.time() - t0
270
+ return out, f"耗时:{dt:.2f}s", state
271
+ except Exception as e:
272
+ return "", f"推理失败:{e}", state
273
+
274
  # 示例数据
275
  test_list_zh = [
276
  "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
 
333
  out_text = gr.Textbox(label="重写结果", lines=10)
334
  out_info = gr.Markdown("准备就绪。")
335
 
336
+ # run_btn.click(
337
+ # stream_single,
338
+ # inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
339
+ # model_path, device_map, torch_dtype, state],
340
+ # outputs=[out_text, out_info, state]
341
+ # )
342
  run_btn.click(
343
+ run_single,
344
  inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
345
+ model_path, device_map, torch_dtype, state],
346
  outputs=[out_text, out_info, state]
347
+ )
348
 
349
  gr.Markdown(
350
  "提示:如有任何问题可email联系:linqing1995@buaa.edu.cn"