Update orpheus-tts/engine_class.py

orpheus-tts/engine_class.py (+16 -24)
@@ -86,33 +86,25 @@ class OrpheusModel:
         if voice not in self.engine.available_voices:
             raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
 
-    def _format_prompt(self, prompt, voice="
-
-
-
-            else:
-                return f"<custom_token_3>{prompt}<custom_token_4><custom_token_5>"
+    def _format_prompt(self, prompt, voice="Sophie", model_type="larger"):
+        # Use Kartoffel model format based on documentation
+        if voice:
+            full_prompt = f"{voice}: {prompt}"
         else:
-
-
-
-
-
-
-
-
-
-
-
-                end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
-                all_input_ids = torch.cat([start_token, prompt_tokens.input_ids, end_tokens], dim=1)
-                prompt_string = self.tokenizer.decode(all_input_ids[0])
-                return prompt_string
+            full_prompt = prompt
+
+        # Kartoffel model token format
+        start_token = torch.tensor([[128259]], dtype=torch.int64)
+        end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
+
+        input_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids
+        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
+
+        prompt_string = self.tokenizer.decode(modified_input_ids[0])
+        return prompt_string
 
 
-
-
-    def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids = [128258], repetition_penalty=1.3):
+    def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.95, max_tokens=4000, stop_token_ids = [128258], repetition_penalty=1.1):
         prompt_string = self._format_prompt(prompt, voice)
         print(f"DEBUG: Original prompt: {prompt}")
         print(f"DEBUG: Formatted prompt: {prompt_string}")