Alon Albalak committed on
Commit
fdc1169
·
1 Parent(s): 6cef7dd

fix minor bug

Browse files
Files changed (1) hide show
  1. src/models/llm_manager.py +3 -3
src/models/llm_manager.py CHANGED
@@ -78,8 +78,8 @@ class LLMManager:
78
 
79
  with torch.no_grad():
80
  outputs = self.model.generate(
81
- inputs.input_ids,
82
- attention_mask=inputs.attention_mask,
83
  max_new_tokens=1000,
84
  do_sample=True,
85
  top_p=0.95,
@@ -90,6 +90,6 @@ class LLMManager:
90
  # Move output back to CPU and decode
91
  outputs = outputs.cpu()
92
 
93
- full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
94
  assistant_part = full_response.split("Assistant: ")[-1]
95
  return assistant_part
 
78
 
79
  with torch.no_grad():
80
  outputs = self.model.generate(
81
+ inputs['input_ids'],
82
+ attention_mask=inputs['attention_mask'],
83
  max_new_tokens=1000,
84
  do_sample=True,
85
  top_p=0.95,
 
90
  # Move output back to CPU and decode
91
  outputs = outputs.cpu()
92
 
93
+ full_response = self.tokenizer.decode(outputs[0].cpu(), skip_special_tokens=True)
94
  assistant_part = full_response.split("Assistant: ")[-1]
95
  return assistant_part