Dementia is a loss of cognitive function that affects how a person uses language and communicates.
Early detection and intervention are crucial for improving the quality of life for affected individuals.
This project aims to harness the power of natural language processing and deep learning to diagnose dementia at an
early stage based on speech patterns.
By utilizing GPT-3 text embeddings, we seek to develop an innovative tool that can assist in the early
detection of dementia.
This project was created during my time as a research assistant at the TECO research group at the Karlsruhe Institute of Technology (KIT). It involves several steps:
- Transcription using Whisper
- Creation of GPT-3 text embeddings to capture the semantic content of the transcribed text
- Training machine learning models using the extracted features
- Prediction on test data
To get a local copy up and running, follow these steps.
To gain access to the data used in this project, you must first join as a DementiaBank member.
For simplicity, we used the dataset provided for the ADReSSo challenge, which has been balanced with respect to age and gender in order to eliminate potential confounding and bias.
In our case, the downloaded ADReSSo audio files were in a format that Whisper could not transcribe, so we first had to convert them with ffmpeg. Since we cannot convert the files in place, we write the converted files to a temporary directory and then replace the originals:
```sh
find . -name '*.wav' -exec sh -c 'mkdir -p fix && ffmpeg -i "$0" "fix/$(basename "$0")"' {} \;
```
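Here `{}` is passed to `sh -c` as `$0`, so each matched `.wav` file is re-encoded by `ffmpeg` into the `fix` directory under its original name.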
Now replace the original audio files with the converted files from the `fix` folder, which can then be deleted.
- Get an OpenAI API Key at https://platform.openai.com/account/api-keys.
- Set an environment variable `OPENAI_API_KEY` (replace `~/.zshrc` with `~/.bashrc` if you use Bash):
```sh
echo "export OPENAI_API_KEY='your key'" >> ~/.zshrc
```
- Clone the repo
```sh
git clone https://github.com/probstlukas/gpt3-dementia-detection.git
```
- Install required Python packages
```sh
pip install -r requirements.txt
```
Before running the program, make sure that the default settings (logging level, embedding engine, directory paths, ...) in `config.py` suit you. Then start the program by running `main.py`.
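For example, from the repository root: `python main.py` (use `python3` if that is how Python is invoked on your system).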
If the data has not yet been transcribed, confirm with `yes`, select your preferred Whisper model, and wait until the transcription process is complete.
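For reference, this is the basic transcription pattern of the `openai-whisper` package (a minimal sketch with a placeholder file path; the project's actual pipeline lives in the repository):

```python
import whisper

# Load one of the available model sizes: tiny, base, small, medium, large.
model = whisper.load_model("base")

# Transcribe a single converted audio file; Whisper resamples internally.
result = model.transcribe("path/to/audio.wav")
print(result["text"])
```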
The embeddings are created separately for training and test data. `train_embeddings.csv` also contains the MMSE score and the diagnosis label for each audio file.
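Creating an embedding for a transcript amounts to a single API call. A minimal sketch, assuming the pre-1.0 `openai` Python package and the `text-embedding-ada-002` engine from the configuration below:

```python
import os
import openai

openai.api_key = os.environ["OPENAI_API_KEY"]

def embed(text: str) -> list[float]:
    """Return the GPT-3 embedding vector for one transcript."""
    response = openai.Embedding.create(
        model="text-embedding-ada-002",
        input=text,
    )
    return response["data"][0]["embedding"]

vector = embed("Example transcript of a picture description task.")
print(len(vector))  # text-embedding-ada-002 returns 1536-dimensional vectors
```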
It is not necessary to scale the embeddings before using them: the vectors returned by the API are already normalized to unit length.
In this step, machine learning models are trained and evaluated on the embeddings. For comparison with the more complex classifiers, we also create a dummy classifier that ignores the input features when making predictions. This gives us a baseline performance (comparable to flipping a coin, i.e. about 50%).
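Such a baseline can be set up with scikit-learn's `DummyClassifier`; a sketch with stand-in data (the `stratified` strategy is an assumption, the repository may configure it differently):

```python
import numpy as np
from sklearn.dummy import DummyClassifier

# Stand-in data: 100 fake "embeddings" of dimension 1536 with binary labels.
rng = np.random.default_rng(42)
X = rng.normal(size=(100, 1536))
y = rng.integers(0, 2, size=100)

# Predicts according to the class distribution, ignoring the features.
dummy = DummyClassifier(strategy="stratified", random_state=42)
dummy.fit(X, y)
print(dummy.score(X, y))  # hovers around 0.5 for balanced binary labels
```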
We use three different models: Support Vector Machine, Logistic Regression, and Random Forest.
The classification process can be divided into two parts (a code sketch follows the list below):
- Perform K-fold cross-validation on the training set: split it into K equal folds; in each round, one fold serves as the validation set and the remaining K-1 folds as the training set.
- Hyperparameter optimization
- Record model performance (accuracy, precision, recall, F1 score)
- Visualize results: one plot for each metric per model, resulting in a total of 12 plots (4 metrics × 3 models). For example, one of these plots shows the accuracy metric for the logistic regression model.
- Train each model on the entire training set with the best hyperparameters.
- Predict labels on the test data using the trained model.
- Evaluate performance by comparing the results to real medical diagnoses.
- Record trained model sizes.
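A condensed sketch of this two-stage process with scikit-learn; the hyperparameter grid, the stand-in data, and the choice of `GridSearchCV` are illustrative assumptions, not the repository's exact settings:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold

# Stand-in data; in the project these are the GPT-3 embeddings and labels.
rng = np.random.default_rng(0)
X_train, y_train = rng.normal(size=(120, 1536)), rng.integers(0, 2, size=120)
X_test, y_test = rng.normal(size=(40, 1536)), rng.integers(0, 2, size=40)

# Part 1: K-fold cross-validation with hyperparameter optimization.
search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    param_grid={"C": [0.01, 0.1, 1, 10]},
    scoring="accuracy",
    cv=StratifiedKFold(n_splits=10, shuffle=True, random_state=0),
)
search.fit(X_train, y_train)

# Part 2: GridSearchCV refits the best model on the full training set;
# evaluate it on the held-out test data.
print(search.best_params_, search.best_estimator_.score(X_test, y_test))
```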
All processed data and results are stored in the directories specified in `config.py`.
These results were obtained with the following configuration:
- Whisper model: `base`
- Embedding engine: `text-embedding-ada-002`
- Number of splits for the K-fold CV: `K=10`
| Model | Size |
|---|---|
| SVC | 1441162 B |
| LR | 13007 B |
| RF | 82947 B |
| Total | 1537116 B |
For the training set, each metric is reported as mean (standard deviation) across the K=10 folds; the dummy classifier is evaluated on the test set only.

| Set | Model | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|
| Train | SVC | 0.779 (0.084) | 0.839 (0.048) | 0.779 (0.084) | 0.772 (0.084) |
| Train | LR | 0.810 (0.085) | 0.847 (0.070) | 0.810 (0.085) | 0.800 (0.091) |
| Train | RF | 0.804 (0.052) | 0.839 (0.044) | 0.804 (0.052) | 0.793 (0.053) |
| Test | SVC | 0.779 | 0.839 | 0.779 | 0.772 |
| Test | LR | 0.810 | 0.847 | 0.810 | 0.800 |
| Test | RF | 0.804 | 0.839 | 0.804 | 0.793 |
| Test | Dummy | 0.425 | - | - | - |
Our results show that GPT-3 text embeddings can be used to reliably distinguish individuals with Alzheimer's disease from healthy controls based solely on their speech.
Distributed under the MIT License. See `LICENSE.txt` for more information.