Update app.py
Browse files
app.py
CHANGED
|
@@ -27,7 +27,7 @@ def load_model(model_id: str = MODEL_ID):
|
|
| 27 |
mdl = AutoModelForCausalLM.from_pretrained(
|
| 28 |
model_id,
|
| 29 |
trust_remote_code=True,
|
| 30 |
-
torch_dtype=
|
| 31 |
device_map="auto",
|
| 32 |
)
|
| 33 |
|
|
@@ -138,7 +138,7 @@ def generate_reply(
|
|
| 138 |
max_new_tokens: int = 128,
|
| 139 |
temperature: float = 1.0,
|
| 140 |
top_p: float = 0.8,
|
| 141 |
-
max_retries: int =
|
| 142 |
) -> str:
|
| 143 |
"""Implements the 4 guardrails from Appendix C.1."""
|
| 144 |
messages = build_hf_messages(system_prompt, history_pairs)
|
|
|
|
| 27 |
mdl = AutoModelForCausalLM.from_pretrained(
|
| 28 |
model_id,
|
| 29 |
trust_remote_code=True,
|
| 30 |
+
torch_dtype=torch.bfloat16,
|
| 31 |
device_map="auto",
|
| 32 |
)
|
| 33 |
|
|
|
|
| 138 |
max_new_tokens: int = 128,
|
| 139 |
temperature: float = 1.0,
|
| 140 |
top_p: float = 0.8,
|
| 141 |
+
max_retries: int = 10,
|
| 142 |
) -> str:
|
| 143 |
"""Implements the 4 guardrails from Appendix C.1."""
|
| 144 |
messages = build_hf_messages(system_prompt, history_pairs)
|