All model to have device changed
Browse files- cli/SparkTTS.py +5 -0
cli/SparkTTS.py
CHANGED
|
@@ -49,6 +49,11 @@ class SparkTTS:
|
|
| 49 |
self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
|
| 50 |
self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
|
| 51 |
self.model.to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
def process_prompt(
|
| 54 |
self,
|
|
|
|
| 49 |
self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
|
| 50 |
self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
|
| 51 |
self.model.to(self.device)
|
| 52 |
+
|
| 53 |
+
def to(self, device: torch.device):
|
| 54 |
+
self.device = device
|
| 55 |
+
self.model.to(self.device)
|
| 56 |
+
self.audio_tokenizer.to(self.device)
|
| 57 |
|
| 58 |
def process_prompt(
|
| 59 |
self,
|