Skip to content

Commit

Permalink
LightGBM working
Browse files Browse the repository at this point in the history
  • Loading branch information
shalinis602 committed Jul 23, 2024
1 parent 92efd4f commit 2f662be
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 16 deletions.
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
## Overview

This project is my implementation of [this research paper](https://ernest-bonat.medium.com/rna-seq-gene-expression-classification-using-machine-learning-algorithms-de862e60bfd0#4592) from scratch. The objective of this project is to identify gene expression patterns associated with different conditions or diseases, leveraging advanced data processing and model training techniques. RNA sequencing (RNA-Seq) is a technique used to quantify and analyze gene expression levels across different conditions or samples. The classification of RNA-Seq data can be used to identify which genes are differentially expressed between healthy and diseased samples, or between different diseased states, thus aiding in the diagnosis, treatment, and understanding of various diseases and conditions.
The objective of this project is to identify gene expression patterns associated with different conditions or diseases, leveraging advanced data processing and model training techniques. RNA sequencing (RNA-Seq) is a technique used to quantify and analyze gene expression levels across different conditions or samples. The classification of RNA-Seq data can be used to identify which genes are differentially expressed between healthy and diseased samples, or between different diseased states, thus aiding in the diagnosis, treatment, and understanding of various diseases and conditions.

## Table of Contents

- [Overview](#overview)
- [Table of Contents](#table-of-contents)
- [Dataset](#dataset)
- [Project Structure](#project-structure)
- [Running the project](#running-the-project)
- [Running the Project](#running-the-project)
- [Models Used](#models-used)
- [Results](#results)
- [Contributing](#contributing)
- [Acknowledgements](#acknowledgements)
- [**References**](#references)

## Dataset

Expand Down Expand Up @@ -163,5 +164,5 @@ If you find any issues or have suggestions for improvements or expanding the pro
4. Push to the branch (`git push origin feature-branch`).
5. Create a new Pull Request.

## **Acknowledgements**

## **References**
1. [RNA-Seq Gene Expression Classification Using Machine Learning Algorithms](https://ernest-bonat.medium.com/rna-seq-gene-expression-classification-using-machine-learning-algorithms-de862e60bfd0)
41 changes: 41 additions & 0 deletions results/svm.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
Validation Accuracy: 0.98
Validation Classification Report:
precision recall f1-score support

BRCA 0.97 1.00 0.98 28
COAD 1.00 1.00 1.00 9
KIRC 1.00 0.90 0.95 10
LUAD 0.92 0.92 0.92 13
PRAD 1.00 1.00 1.00 21

accuracy 0.98 81
macro avg 0.98 0.96 0.97 81
weighted avg 0.98 0.98 0.98 81

Validation Confusion Matrix:
[[28 0 0 0 0]
[ 0 9 0 0 0]
[ 0 0 9 1 0]
[ 1 0 0 12 0]
[ 0 0 0 0 21]]

Test Accuracy: 0.99
Test Classification Report:
precision recall f1-score support

BRCA 0.96 1.00 0.98 27
COAD 1.00 1.00 1.00 8
KIRC 1.00 1.00 1.00 15
LUAD 1.00 0.95 0.97 19
PRAD 1.00 1.00 1.00 11

accuracy 0.99 80
macro avg 0.99 0.99 0.99 80
weighted avg 0.99 0.99 0.99 80

Test Confusion Matrix:
[[27 0 0 0 0]
[ 0 8 0 0 0]
[ 0 0 15 0 0]
[ 1 0 0 18 0]
[ 0 0 0 0 11]]
25 changes: 15 additions & 10 deletions src/models/light_gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def pickle_serialize_object(filename, obj):
def main():
# Deserialize the input
input_dir = 'data/processed'
X_train_pca = pickle_deserialize_object(os.path.join(input_dir, 'X_train_pca.pkl'))
y_train_resampled = pickle_deserialize_object(os.path.join(input_dir, 'y_train_resampled.pkl'))

input_dir2 = 'data/processed/transformed'
X_train_pca = pickle_deserialize_object(os.path.join(input_dir2, 'X_train_pca.pkl'))
X_val_pca = pickle_deserialize_object(os.path.join(input_dir2, 'X_val_pca.pkl'))
X_test_pca = pickle_deserialize_object(os.path.join(input_dir2, 'X_test_pca.pkl'))

Expand All @@ -37,12 +37,16 @@ def main():

# Define parameter grid for GridSearchCV
param_grid = {
'n_estimators': [50, 100],
'learning_rate': [0.01],
'num_leaves': [31],
'max_depth': [10, 20],
'min_child_samples': [100, 200],
'force_col_wise': [True]
'min_child_samples': [10],
'min_split_gain': [0.01],
'learning_rate': [0.1,],
'max_depth': [-1, 10, 20],
'reg_alpha': [0, 0.1, 0.5],
'reg_lambda': [0, 0.1, 0.5],
'subsample': [0.8],
'colsample_bytree': [0.8],
'n_estimators': [100]
}

# Initialize and fit LGBMClassifier with GridSearchCV
Expand All @@ -60,25 +64,26 @@ def main():
val_confusion_matrix = confusion_matrix(y_val, y_val_pred)

# Evaluate on test set
'''
y_test_pred = best_lgbm.predict(X_test_pca)
test_accuracy = accuracy_score(y_test, y_test_pred)
test_classification_report = classification_report(y_test, y_test_pred)
test_confusion_matrix = confusion_matrix(y_test, y_test_pred)

'''
# Write results to a file
output_filename = 'results/lightgbm.txt'
output_filename = 'results/light_gbm.out'
with open(output_filename, 'w') as f:
f.write(f"Validation Accuracy: {val_accuracy:.2f}\n")
f.write("Validation Classification Report:\n")
f.write(val_classification_report + '\n')
f.write("Validation Confusion Matrix:\n")
f.write(str(val_confusion_matrix) + '\n\n')

'''
f.write(f"Test Accuracy: {test_accuracy:.2f}\n")
f.write("Test Classification Report:\n")
f.write(test_classification_report + '\n')
f.write("Test Confusion Matrix:\n")
f.write(str(test_confusion_matrix) + '\n')

'''
if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion src/models/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def pickle_serialize_object(filename, obj):
def main():
# Deserialize the input
input_dir = 'data/processed'
X_train_pca = pickle_deserialize_object(os.path.join(input_dir, 'X_train_pca.pkl'))
y_train_resampled = pickle_deserialize_object(os.path.join(input_dir, 'y_train_resampled.pkl'))

input_dir2 = 'data/processed/transformed'
X_train_pca = pickle_deserialize_object(os.path.join(input_dir2, 'X_train_pca.pkl'))
X_val_pca = pickle_deserialize_object(os.path.join(input_dir2, 'X_val_pca.pkl'))
X_test_pca = pickle_deserialize_object(os.path.join(input_dir2, 'X_test_pca.pkl'))

Expand Down
Empty file removed src/models/svm.py
Empty file.
82 changes: 82 additions & 0 deletions src/models/svm_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import os
import sys
import datetime
import pickle
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import GridSearchCV

# Generate a timestamp for this run
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_file = f"logs/svm_{timestamp}.out"

# open(..., 'a') creates the file but not its parent directory, so make sure
# logs/ exists; otherwise the script dies with FileNotFoundError on import.
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# Redirect stdout and stderr to the log file. Both streams share a single
# handle so their output reaches the file in the order it was produced,
# instead of being flushed independently from two separate buffers.
_log_stream = open(log_file, 'a')
sys.stdout = _log_stream
sys.stderr = _log_stream

def pickle_deserialize_object(filename):
    """Load and return the object stored in the pickle file at *filename*.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(filename, mode='rb') as source:
        loaded = pickle.load(source)
    return loaded

def pickle_serialize_object(filename, obj):
    """Pickle *obj* into the file at *filename*, overwriting any existing file."""
    with open(filename, mode='wb') as sink:
        pickle.dump(obj, sink)
    return None

def _evaluate_split(model, features, labels):
    """Score *model* on one data split.

    Returns a tuple ``(accuracy, report, matrix)`` obtained by passing
    *labels* and the model's predictions for *features* to
    ``accuracy_score``, ``classification_report`` and ``confusion_matrix``.
    """
    predictions = model.predict(features)
    return (
        accuracy_score(labels, predictions),
        classification_report(labels, predictions),
        confusion_matrix(labels, predictions),
    )


def _format_split_results(split_name, accuracy, report, matrix):
    """Render one split's metrics as the text section stored in the results file."""
    return "".join([
        f"{split_name} Accuracy: {accuracy:.2f}\n",
        f"{split_name} Classification Report:\n",
        report + '\n',
        f"{split_name} Confusion Matrix:\n",
        str(matrix) + '\n',
    ])


def main():
    """Tune an SVM with grid search and record validation and test metrics.

    Reads the pickled training, validation and test data from
    ``data/processed`` and its sub-directories, fits ``SVC`` through a
    3-fold ``GridSearchCV`` over ``C``, ``gamma`` and ``kernel``, evaluates
    the best estimator on the validation and test splits, and writes the
    accuracy, classification report and confusion matrix of both splits to
    ``results/svm.txt``.
    """
    # Deserialize the input
    input_dir = 'data/processed'
    # NOTE(review): X_train_pca is read from data/processed while
    # X_val_pca / X_test_pca come from data/processed/transformed -- confirm
    # all three were produced by the same PCA fit.
    X_train_pca = pickle_deserialize_object(os.path.join(input_dir, 'X_train_pca.pkl'))
    y_train_resampled = pickle_deserialize_object(os.path.join(input_dir, 'y_train_resampled.pkl'))

    input_dir2 = 'data/processed/transformed'
    X_val_pca = pickle_deserialize_object(os.path.join(input_dir2, 'X_val_pca.pkl'))
    X_test_pca = pickle_deserialize_object(os.path.join(input_dir2, 'X_test_pca.pkl'))

    input_dir3 = 'data/processed/split_data'
    y_val = pickle_deserialize_object(os.path.join(input_dir3, 'y_val.pkl'))
    y_test = pickle_deserialize_object(os.path.join(input_dir3, 'y_test.pkl'))

    # Create the output directory before training so that a missing
    # results/ folder cannot discard the outcome of the grid search.
    output_filename = 'results/svm.txt'
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)

    # Define parameter grid for GridSearchCV
    param_grid = {
        'C': [0.1, 1, 10],
        'gamma': [0.001, 0.01, 0.1],
        'kernel': ['rbf', 'linear'],
    }

    # Initialize and fit SVM with GridSearchCV
    svm = SVC(random_state=1)
    grid_search = GridSearchCV(svm, param_grid, cv=3, n_jobs=-1, verbose=1)
    grid_search.fit(X_train_pca, y_train_resampled)

    # Get the best estimator
    best_svm = grid_search.best_estimator_

    # Evaluate on the validation set, then on the test set
    val_section = _format_split_results(
        "Validation", *_evaluate_split(best_svm, X_val_pca, y_val)
    )
    test_section = _format_split_results(
        "Test", *_evaluate_split(best_svm, X_test_pca, y_test)
    )

    # Write results to a file; a blank line separates the two sections
    with open(output_filename, 'w') as f:
        f.write(val_section + '\n' + test_section)

# Run the SVM training/evaluation pipeline only when this file is executed
# as a script, not when it is imported.
if __name__ == "__main__":
    main()

0 comments on commit 2f662be

Please sign in to comment.