From d1905bdb8fc5130ab6fb29c757cb2b4510746a73 Mon Sep 17 00:00:00 2001 From: BBC-Esq Date: Thu, 29 Aug 2024 11:46:25 -0400 Subject: [PATCH] change combobox logic --- settings.py | 34 +++++++++++-- utilities.py | 22 +++++---- whispers2t_batch_gui.py | 84 +++++++++++++++------------------ whispers2t_batch_transcriber.py | 4 +- 4 files changed, 85 insertions(+), 59 deletions(-) diff --git a/settings.py b/settings.py index f5df9b2..5d37e90 100644 --- a/settings.py +++ b/settings.py @@ -1,8 +1,13 @@ from PySide6.QtWidgets import (QGroupBox, QVBoxLayout, QHBoxLayout, QLabel, QComboBox, QSlider) -from PySide6.QtCore import Qt +from PySide6.QtCore import Qt, Signal from constants import WHISPER_MODELS +import torch + +from utilities import has_bfloat16_support class SettingsGroupBox(QGroupBox): + device_changed = Signal(str) + def __init__(self, get_compute_and_platform_info_callback, parent=None): super().__init__("Settings", parent) self.get_compute_and_platform_info = get_compute_and_platform_info_callback @@ -17,7 +22,6 @@ def initUI(self): hbox1_layout.addWidget(modelLabel) self.modelComboBox = QComboBox() - self.modelComboBox.addItems(WHISPER_MODELS.keys()) hbox1_layout.addWidget(self.modelComboBox) computeDeviceLabel = QLabel("Device:") @@ -25,6 +29,7 @@ def initUI(self): self.computeDeviceComboBox = QComboBox() hbox1_layout.addWidget(self.computeDeviceComboBox) + self.computeDeviceComboBox.currentTextChanged.connect(self.on_device_changed) formatLabel = QLabel("Output:") hbox1_layout.addWidget(formatLabel) @@ -67,7 +72,7 @@ def initUI(self): self.batchSizeSlider = QSlider(Qt.Horizontal) self.batchSizeSlider.setMinimum(1) self.batchSizeSlider.setMaximum(200) - self.batchSizeSlider.setValue(16) + self.batchSizeSlider.setValue(8) self.batchSizeSlider.setTickPosition(QSlider.TicksBelow) self.batchSizeSlider.setTickInterval(10) batch_size_layout.addWidget(self.batchSizeSlider) @@ -86,4 +91,25 @@ def update_slider_label(self, slider, label): def 
populateComputeDeviceComboBox(self): available_devices = self.get_compute_and_platform_info() self.computeDeviceComboBox.addItems(available_devices) - self.computeDeviceComboBox.setCurrentIndex(self.computeDeviceComboBox.findText("cpu")) \ No newline at end of file + if "cuda" in available_devices: + self.computeDeviceComboBox.setCurrentIndex(self.computeDeviceComboBox.findText("cuda")) + else: + self.computeDeviceComboBox.setCurrentIndex(self.computeDeviceComboBox.findText("cpu")) + self.update_model_combobox() + + def on_device_changed(self, device): + self.device_changed.emit(device) + self.update_model_combobox() + + def update_model_combobox(self): + current_device = self.computeDeviceComboBox.currentText() + self.modelComboBox.clear() + + for model_name, model_info in WHISPER_MODELS.items(): + if current_device == "cpu" and model_info['precision'] == 'float32': + self.modelComboBox.addItem(model_name) + elif current_device == "cuda": + if model_info['precision'] in ['float32', 'float16']: + self.modelComboBox.addItem(model_name) + elif model_info['precision'] == 'bfloat16' and has_bfloat16_support(): + self.modelComboBox.addItem(model_name) diff --git a/utilities.py b/utilities.py index b5902bb..c700c42 100644 --- a/utilities.py +++ b/utilities.py @@ -11,13 +11,19 @@ def get_compute_and_platform_info(): return available_devices -def get_supported_quantizations(device_type): - types = ctranslate2.get_supported_compute_types(device_type) - filtered_types = [q for q in types if q != 'int16'] - desired_order = ['float32', 'float16', 'bfloat16', 'int8_float32', 'int8_float16', 'int8_bfloat16', 'int8'] - sorted_types = [q for q in desired_order if q in filtered_types] - return sorted_types +# def get_supported_quantizations(device_type): + # types = ctranslate2.get_supported_compute_types(device_type) + # filtered_types = [q for q in types if q != 'int16'] + # desired_order = ['float32', 'float16', 'bfloat16', 'int8_float32', 'int8_float16', 'int8_bfloat16', 
'int8'] + # sorted_types = [q for q in desired_order if q in filtered_types] + # return sorted_types +def get_logical_core_count(): + return psutil.cpu_count(logical=True) -def get_physical_core_count(): - return psutil.cpu_count(logical=False) \ No newline at end of file +def has_bfloat16_support(): + if not torch.cuda.is_available(): + return False + + capability = torch.cuda.get_device_capability() + return capability >= (8, 0) \ No newline at end of file diff --git a/whispers2t_batch_gui.py b/whispers2t_batch_gui.py index 44e93af..05da766 100644 --- a/whispers2t_batch_gui.py +++ b/whispers2t_batch_gui.py @@ -1,16 +1,29 @@ +import logging import os import sys +import traceback from pathlib import Path -from PySide6.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QCheckBox, QLabel, QGroupBox, QMessageBox + +import torch from PySide6.QtCore import Qt -import torch -from utilities import get_compute_and_platform_info, get_supported_quantizations -from whispers2t_batch_transcriber import Worker +from PySide6.QtWidgets import ( + QApplication, + QCheckBox, + QFileDialog, + QGroupBox, + QHBoxLayout, + QLabel, + QMessageBox, + QPushButton, + QVBoxLayout, + QWidget, +) + +from constants import WHISPER_MODELS from metrics_bar import MetricsBar from settings import SettingsGroupBox -import logging -import traceback -from constants import WHISPER_MODELS +from utilities import has_bfloat16_support +from whispers2t_batch_transcriber import Worker def set_cuda_paths(): try: @@ -26,9 +39,6 @@ def set_cuda_paths(): set_cuda_paths() -def is_nvidia_gpu_available(): - return torch.cuda.is_available() and "nvidia" in torch.cuda.get_device_name(0).lower() - class MainWindow(QWidget): def __init__(self): super().__init__() @@ -36,8 +46,7 @@ def __init__(self): def initUI(self): self.setWindowTitle("chintellalaw.com - for non-commercial use") - initial_height = 400 if is_nvidia_gpu_available() else 370 - self.setGeometry(100, 100, 680, 
initial_height) + self.setGeometry(100, 100, 680, 400) self.setWindowFlags(self.windowFlags() | Qt.WindowStaysOnTopHint) main_layout = QVBoxLayout() @@ -69,7 +78,8 @@ def initUI(self): fileExtensionsGroupBox.setLayout(fileExtensionsLayout) main_layout.addWidget(fileExtensionsGroupBox) - self.settingsGroupBox = SettingsGroupBox(get_compute_and_platform_info, self) + self.settingsGroupBox = SettingsGroupBox(self.get_compute_and_platform_info, self) + self.settingsGroupBox.device_changed.connect(self.on_device_changed) main_layout.addWidget(self.settingsGroupBox) selectDirLayout = QHBoxLayout() @@ -101,6 +111,16 @@ def closeEvent(self, event): self.metricsBar.stop_metrics_collector() super().closeEvent(event) + def get_compute_and_platform_info(self): + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append("cuda") + return devices + + def on_device_changed(self, device): + # You can add any additional logic here if needed when the device is changed + pass + def selectDirectory(self): dirPath = QFileDialog.getExistingDirectory(self, "Select Directory") if dirPath: @@ -122,46 +142,20 @@ def calculate_files_to_process(self): return total_files def perform_checks(self): - model = self.settingsGroupBox.modelComboBox.currentText() device = self.settingsGroupBox.computeDeviceComboBox.currentText() batch_size = self.settingsGroupBox.batchSizeSlider.value() beam_size = self.settingsGroupBox.beamSizeSlider.value() - # Check 1: CPU and non-float32 model - if "float32" not in model.lower() and device.lower() == "cpu": - QMessageBox.warning(self, "Invalid Configuration", - "CPU only supports Float 32 computation. 
Please select a different Whisper model.") - return False - - # Check 2: CPU with high batch size - if device.lower() == "cpu" and batch_size > 16: - reply = QMessageBox.warning(self, "Performance Warning", - "When using CPU it is generally recommended to use a batch size of no more than 16 " - "otherwise compute time will actually be worse.\n\n" - "Moreover, if you select a Beam Size greater than one, you should reduce the Batch Size accordingly.\n\n" - "For example:\n" - "- If you select a Beam Size of 2 (double the default value of 1) you would reduce the Batch Size (default value 16) by half.\n" - "- If Beam Size is set to 3 you should reduce the Batch Size to 1/3 of the default level, and so on.\n\nClick OK to proceed.", + # Check: CPU with high batch size + if device.lower() == "cpu" and batch_size > 8: + reply = QMessageBox.warning(self, "Warning", + "When using CPU it is generally recommended to use a batch size of no more than 8 " + "otherwise compute could actually be worse. Use at your own risk.", QMessageBox.Ok | QMessageBox.Cancel, QMessageBox.Cancel) if reply == QMessageBox.Cancel: return False - # Check 3: GPU compatibility - # Only perform this check if the device is not CPU - if device.lower() != "cpu": - supported_quantizations = get_supported_quantizations(device) - if "float16" in model.lower() and "float16" not in supported_quantizations: - QMessageBox.warning(self, "Incompatible Configuration", - "Your GPU does not support the selected floating point value (float16). " - "Please make another selection.") - return False - if "bfloat16" in model.lower() and "bfloat16" not in supported_quantizations: - QMessageBox.warning(self, "Incompatible Configuration", - "Your GPU does not support the selected floating point value (bfloat16). 
" - "Please make another selection.") - return False - return True # All checks passed def processFiles(self): @@ -227,4 +221,4 @@ def workerFinished(self, message): app.setStyle("Fusion") mainWindow = MainWindow() mainWindow.show() - sys.exit(app.exec()) + sys.exit(app.exec()) \ No newline at end of file diff --git a/whispers2t_batch_transcriber.py b/whispers2t_batch_transcriber.py index 6efa773..9191add 100644 --- a/whispers2t_batch_transcriber.py +++ b/whispers2t_batch_transcriber.py @@ -10,9 +10,9 @@ import torch from constants import WHISPER_MODELS -from utilities import get_physical_core_count +from utilities import get_logical_core_count -CPU_THREADS = max(4, get_physical_core_count() - 1) +CPU_THREADS = max(4, get_logical_core_count() - 8) class Worker(QThread): finished = Signal(str)