diff --git a/tortoise/api.py b/tortoise/api.py index 296ef14a..a377661a 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -213,7 +213,12 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable self.models_dir = models_dir self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size self.enable_redaction = enable_redaction - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if not self.device: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = device + if self.device == 'cpu': + print("Using cpu, expect extremely long processing time!") if self.enable_redaction: self.aligner = Wav2VecAlignment()