4. Detailed Model Architecture

stutimishra7 edited this page Feb 6, 2024 · 4 revisions

The Architecture of RCoxNet

RCoxNet integrates Cox proportional hazards regression with a deep neural network that includes a random walk with restart. The RCoxNet architecture consists of (1) a layer for random walk with restart (RWR) scores, (2) multiple hidden layers, (3) a clinical layer incorporating Age, TMB score, and MSI score, and (4) a Cox layer.

(1). RWR Layer: The RWR layer functions as the input layer of RCoxNet and takes RWR scores for n patient samples across p genes. Only genes that are mutated in at least one patient are considered. For each patient, the mutated genes serve as seeds, and the random walk is run on the regulatory gene network from the STRING database. The walk yields a relevance score for every gene with respect to the seeds, which is an intermediate output. In the seed list, each row corresponds to a patient (P1, P2, P3, ... denote patient IDs) and the remaining columns list the genes mutated in that patient in the TCGA data.

Once the random walk has been run for the seeds, a matrix of RWR scores is obtained, with one score for each gene in the mutated gene list and each patient. The matrix has:

Number of rows: number of unique patients

Number of columns: number of genes present in the COSMIC database gene list.

This matrix is later split into training, testing, and validation sets.
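As a rough illustration of how such a patient-by-gene score matrix could be produced, the sketch below runs a random walk with restart on a tiny toy network. The network edges, the seed lists, the restart probability of 0.7, and the function name are all assumptions made for the example; the actual pipeline runs its random walk through an R script on the STRING network.

```python
def random_walk_with_restart(adjacency, seeds, restart_prob=0.7, tol=1e-10, max_iter=1000):
    """Return a dict gene -> RWR score for one patient's seed genes.

    adjacency: dict mapping each gene to the list of its neighbours
               (undirected; every gene must have at least one neighbour).
    restart_prob: probability of jumping back to a seed (assumed value).
    """
    genes = sorted(adjacency)
    # Restart distribution: uniform over the patient's seed genes.
    p0 = {g: (1.0 / len(seeds) if g in seeds else 0.0) for g in genes}
    p = dict(p0)
    for _ in range(max_iter):
        nxt = {g: restart_prob * p0[g] for g in genes}
        for g in genes:
            # Each gene passes its remaining probability equally to its neighbours.
            share = (1.0 - restart_prob) * p[g] / len(adjacency[g])
            for neighbour in adjacency[g]:
                nxt[neighbour] += share
        change = sum(abs(nxt[g] - p[g]) for g in genes)
        p = nxt
        if change < tol:
            break
    return p


# Toy network and seed lists (illustrative only, not taken from STRING or TCGA).
network = {
    "TP53": ["MDM2", "ATM"],
    "MDM2": ["TP53"],
    "ATM": ["TP53", "BRCA1"],
    "BRCA1": ["ATM", "BRCA2"],
    "BRCA2": ["BRCA1"],
}
patient_seeds = {"P1": {"TP53"}, "P2": {"BRCA1", "BRCA2"}}

genes = sorted(network)
score_matrix = {}  # rows: patients, columns: genes (in the order of `genes`)
for pid, seeds in patient_seeds.items():
    scores = random_walk_with_restart(network, seeds)
    score_matrix[pid] = [scores[g] for g in genes]
```

Each row of `score_matrix` sums to 1, and genes close to a patient's seeds receive higher scores than distant ones.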

(2). Hidden Layers: Hidden layers enable neural networks to learn and represent complex relationships in data by capturing hierarchical features, introducing non-linearities, and automatically extracting relevant features.

(3). Clinical Layer: The clinical layer in RCoxNet incorporates clinical covariates: age, TMB (tumor mutational burden) score, and MSI (microsatellite instability) score. These attributes are known to significantly impact prognostic prediction. By feeding these covariates directly into the model's output layer alongside the highest-level representation of the genomic data, RCoxNet keeps the effects of genomic and clinical data separate. This design is intended to improve prognostic accuracy by keeping clinically relevant features from being overlooked.

The following information is appended to the matrices obtained above so that they can be fed into the proposed architecture:

  1. Overall Survival Months
  2. Overall Survival Status
  3. Age
  4. TMB (Tumor Mutational Burden)
  5. MSI Sensor Score

These fields are what the Cox proportional hazards (Cox-PH) model needs to compute the prognostic index. The Cox-PH model is a statistical method for examining the relationship between an individual's survival time and one or more predictor variables. It is especially useful when the proportional hazards assumption holds, because it makes no assumptions about the distribution of survival times. The hazard function describes the instantaneous failure rate at any given time, and the model assumes that the hazard ratio between any two individuals remains constant over time. The output of this module is the prognostic index, which serves as the input to the final step, survival analysis.
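The proportional hazards assumption can be written out in the standard Cox form. The symbols below ($h_0$ for the baseline hazard, $\mathrm{PI}_i$ for the prognostic index of patient $i$) are notation chosen for this note rather than taken from the RCoxNet code:

$$h(t \mid x_i) = h_0(t)\,\exp(\mathrm{PI}_i), \qquad \frac{h(t \mid x_i)}{h(t \mid x_j)} = \exp(\mathrm{PI}_i - \mathrm{PI}_j)$$

The baseline hazard $h_0(t)$ cancels in the ratio, so the hazard ratio between two patients does not depend on $t$. This is why no distribution has to be assumed for the survival times.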

(4). Cox Layer: This is the output layer, which has a single linear node. The score produced by this node is termed the Prognostic Index (PI). Following the Cox model, this layer has no bias term. The PI is then used to stratify patients into high-risk and low-risk groups based on its median value.
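A minimal sketch of the median split is shown below. The function name, the example PI values, and the choice to place patients exactly at the median in the low-risk group are assumptions for illustration.

```python
from statistics import median


def stratify_by_median(prognostic_index):
    """Map each patient ID to "high" or "low" risk using the median PI as cutoff."""
    cutoff = median(prognostic_index.values())
    # Assumption: a PI equal to the cutoff is assigned to the low-risk group.
    return {
        pid: ("high" if pi > cutoff else "low")
        for pid, pi in prognostic_index.items()
    }


# Made-up prognostic index values for four patients.
example_pi = {"P1": -0.8, "P2": 0.1, "P3": 1.4, "P4": 0.5}
risk_groups = stratify_by_median(example_pi)
```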

Prognostic index (PI): This index helps predict how long individuals in our study might survive based on various factors. We then divided people into two groups - one with a higher risk of an event (like death) and another with a lower risk, using their PI values. To check if this division was meaningful, we used a test called the log-rank test. This test compares how long people in the high-risk and low-risk groups survive to see if there's a significant difference. The test gives us a p-value, and if it's small (typically less than 0.05), it suggests that our prognostic index is good at separating people into groups with different survival experiences. This helps us understand who might be at a higher or lower risk in our study, which is important for making informed decisions about treatments and care.
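To make the log-rank comparison concrete, here is a small pure-Python sketch of the two-group test. It is not the project's code (a survival analysis library would normally be used); the function name and the example survival times are invented for illustration.

```python
import math


def logrank_test(times_a, events_a, times_b, events_b):
    """Two-group log-rank test. Returns (chi-square statistic, p-value)."""
    samples = [(t, e, 0) for t, e in zip(times_a, events_a)]
    samples += [(t, e, 1) for t, e in zip(times_b, events_b)]
    event_times = sorted({t for t, e, _ in samples if e == 1})

    observed_minus_expected = 0.0
    variance = 0.0
    for t in event_times:
        at_risk = [s for s in samples if s[0] >= t]
        n = len(at_risk)
        n_a = sum(1 for s in at_risk if s[2] == 0)
        d = sum(1 for s in at_risk if s[0] == t and s[1] == 1)
        d_a = sum(1 for s in at_risk if s[0] == t and s[1] == 1 and s[2] == 0)
        # Observed events in group A minus the number expected under no difference.
        observed_minus_expected += d_a - d * n_a / n
        if n > 1:
            variance += d * (n_a / n) * (1 - n_a / n) * (n - d) / (n - 1)

    if variance == 0:
        return 0.0, 1.0
    chi2 = observed_minus_expected ** 2 / variance
    # Survival function of a chi-square distribution with 1 degree of freedom.
    p_value = math.erfc(math.sqrt(chi2 / 2.0))
    return chi2, p_value


# Made-up survival months; event = 1 means death observed, 0 means censored.
high_times, high_events = [2, 3, 4, 5, 6], [1, 1, 1, 1, 1]
low_times, low_events = [20, 22, 25, 30, 35], [1, 0, 1, 0, 1]
chi2, p_value = logrank_test(high_times, high_events, low_times, low_events)
```

With these invented numbers the two groups are clearly separated, so the p-value falls well below 0.05.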

Objective Function:

To conduct Cox proportional hazards regression within the Cox layer, RCoxNet formulates the objective function as the average negative log partial likelihood plus L2 regularization. The negative log partial likelihood (NLP) is a loss commonly used in survival analysis, particularly when dealing with censored data. It is minimized to find the parameter values that best fit the observed survival data. The L2 regularization term penalizes large coefficient values to prevent overfitting.

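$$\ell(\mu) = -\frac{1}{n_E} \sum_{i \in E} \left( h_i^{l}\,\delta - \log \sum_{j \in R(T_i)} \exp\left( h_j^{l}\,\delta \right) \right) + \alpha \lVert \mu \rVert_2$$

In this equation, $\mu = \{\delta, W\}$ is the set of parameters to be optimized, where $\delta$ is the Cox proportional hazards coefficient vector, that is, the weights between the last hidden layer and the Cox layer. $W$ is the union of the weight matrices across the architecture before the Cox layer. $h_i^{l}$ is the output of the very last hidden layer for sample $i$, together with the clinical information. $E$ is the set of uncensored samples, $n_E$ is the total number of uncensored events, and $\alpha$ is the regularization hyperparameter. $R(T_i) = \{ j \mid T_j \ge T_i \}$ is the set of all samples still at risk of failure at time $T_i$.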
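The unregularized part of this objective can be spelled out in a few lines of plain Python. This is a sketch for checking the formula by hand, not the project's loss function: the name `average_neg_log_partial_likelihood` and the example values are hypothetical, and `pred` stands for the prognostic index of each sample.

```python
import math


def average_neg_log_partial_likelihood(pred, ytime, yevent):
    """Average negative log partial likelihood over uncensored samples.

    pred:   prognostic index of each sample
    ytime:  survival time of each sample
    yevent: 1 if the event was observed, 0 if the sample is censored
    """
    n_events = sum(yevent)
    total = 0.0
    for i, (t_i, e_i) in enumerate(zip(ytime, yevent)):
        if e_i != 1:
            continue  # censored samples only appear inside risk sets
        # Risk set R(T_i): every sample whose time is at least T_i.
        risk_sum = sum(math.exp(pred[j]) for j, t_j in enumerate(ytime) if t_j >= t_i)
        total += pred[i] - math.log(risk_sum)
    return -total / n_events


ytime = [5.0, 3.0, 1.0]
yevent = [1, 1, 1]
# Higher predicted risk for shorter survival should give a lower loss.
loss_good = average_neg_log_partial_likelihood([-1.0, 0.0, 1.0], ytime, yevent)
loss_bad = average_neg_log_partial_likelihood([1.0, 0.0, -1.0], ytime, yevent)
loss_flat = average_neg_log_partial_likelihood([0.0, 0.0, 0.0], ytime, yevent)
```

When all predictions are equal, the loss reduces to the mean log size of the risk sets, here log(6) / 3.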

Model Training & Hyperparameter Tuning

The data frame was split 70%/10%/20% into training, validation, and testing data, respectively. We used the Adam optimizer to train the model. Adam is a first-order gradient-based optimization method that adapts the step size of each parameter using running estimates of the first and second moments of the gradients. Early stopping was used to halt training once the validation loss stopped improving. We used grid search for hyperparameter optimization, exhaustively exploring combinations of L2 regularization (0.1, 0.01, 0.005, 0.001) and learning rates (0.03, 0.01, 0.0075, 0.001). The model is trained and evaluated for each combination to identify the hyperparameters with the lowest validation loss, and these are then used to build the final model. The model was retrained more than ten times to check that its performance is reproducible.
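Early stopping is usually implemented with a patience counter. The sketch below shows one common variant; the patience value, the `min_delta` threshold, and the function name are assumptions, since the page does not state the exact stopping rule used.

```python
def early_stopping_epoch(validation_losses, patience=5, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses.

    Training stops at the first epoch that is `patience` epochs past the best
    epoch without an improvement larger than `min_delta`.
    """
    best_loss = float("inf")
    best_epoch = -1
    for epoch, loss in enumerate(validation_losses):
        if loss < best_loss - min_delta:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    # Never triggered: training ran for all epochs.
    return len(validation_losses) - 1, best_epoch


# Made-up validation losses that stop improving after epoch 2.
history = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.74]
stop_epoch, best_epoch = early_stopping_epoch(history, patience=3)
```

In practice the model weights from `best_epoch` would be restored before evaluating on the test set.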

Training RCoxNet:

i. Initialization:

  • Initialize the RCoxNet model with specified architecture parameters such as the number of input nodes (In_Nodes), hidden nodes (Hidden_Nodes), and output nodes (Out_Nodes).
  • Define the hyperparameters, including the initial learning rate (Initial_Learning_Rate), L2 regularization lambda (L2_Lambda), and the number of training epochs (Num_EPOCHS).
  • Set up the optimizer, in this case using the Adam optimizer, to update the model parameters during training.

ii. Load Data:

  • Load the training data (x_train, age_train, msi_train, tmb_train, ytime_train, yevent_train) and validation data (x_valid, age_valid, msi_valid, tmb_valid, ytime_valid, yevent_valid) using the load_data function.
  • The data includes genomic inputs (x), age, MSI (Microsatellite Instability) score, TMB (Tumor Mutational Burden) score, survival time (ytime), and censoring status (yevent).

iii. Grid Search for Hyperparameters:

  • Perform a grid search over hyperparameters (L2 regularization and learning rate) using the training and validation data.
  • Train the RCoxNet for each combination of hyperparameters and evaluate the model's performance on the validation set.
  • Choose the hyperparameters that result in the best validation loss.
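The grid search step can be sketched as follows. The value grids are the ones listed on this page; `train_and_validate` is a hypothetical callback standing in for "train RCoxNet with these settings and return the validation loss", and the stand-in scoring function exists only to make the example runnable.

```python
from itertools import product

L2_VALUES = [0.1, 0.01, 0.005, 0.001]
LEARNING_RATES = [0.03, 0.01, 0.0075, 0.001]


def grid_search(train_and_validate, l2_values=L2_VALUES, learning_rates=LEARNING_RATES):
    """Try every (L2, learning rate) pair and keep the one with the lowest validation loss."""
    best = None
    for l2, lr in product(l2_values, learning_rates):
        val_loss = train_and_validate(l2, lr)
        if best is None or val_loss < best["val_loss"]:
            best = {"val_loss": val_loss, "l2": l2, "learning_rate": lr}
    return best


def fake_train_and_validate(l2, lr):
    # Stand-in for real training: pretends the best setting is l2=0.01, lr=0.0075.
    return abs(l2 - 0.01) + abs(lr - 0.0075)


best_params = grid_search(fake_train_and_validate)
```

The search evaluates all 16 combinations, so its cost grows with the product of the grid sizes.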

iv. Training Loop:

  • Execute the training loop for the specified number of epochs (Num_EPOCHS).
  • For each epoch:
    • Set the model to training mode and zero out the gradients (opt.zero_grad()).
    • Forward pass: Pass the training data through the model (net) to obtain predictions.
    • Compute the negative partial log-likelihood loss using the neg_par_log_likelihood function.
    • Backward pass: Compute gradients and update the model parameters using the Adam optimizer (opt.step()).
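The loop above could look roughly like the PyTorch sketch below. It is an illustration under several assumptions: a single bias-free linear layer stands in for the full network, the data are random, the L2 penalty is applied through Adam's `weight_decay`, and the signature of `neg_par_log_likelihood` is a guess based on its name on this page.

```python
import torch
from torch import nn

torch.manual_seed(0)


def neg_par_log_likelihood(pred, ytime, yevent):
    """Average negative log partial likelihood for predictions of shape (n, 1)."""
    # risk_set[i, j] = 1 when sample j is still at risk at the time of sample i.
    risk_set = (ytime.view(1, -1) >= ytime.view(-1, 1)).float()
    log_risk = torch.log(risk_set @ torch.exp(pred))
    return -torch.sum((pred - log_risk) * yevent) / torch.sum(yevent)


# Random stand-in data: 32 samples with 4 features each.
x_train = torch.randn(32, 4)
ytime_train = torch.rand(32) * 100
yevent_train = (torch.rand(32, 1) < 0.7).float()

net = nn.Linear(4, 1, bias=False)  # placeholder for the RCoxNet model
opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=0.001)

losses = []
for epoch in range(50):
    net.train()                      # training mode
    opt.zero_grad()                  # clear gradients from the previous step
    pred = net(x_train)              # forward pass
    loss = neg_par_log_likelihood(pred, ytime_train, yevent_train)
    loss.backward()                  # backward pass
    opt.step()                       # parameter update
    losses.append(loss.item())
```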

v. Model Evaluation:

  • After training, evaluate the model on the test data (x_test, age_test, msi_test, tmb_test, ytime_test, yevent_test).
  • Calculate the negative partial log-likelihood and concordance index for survival analysis using the trained model.
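For reference, the concordance index can be computed directly from its definition. The sketch below is a plain-Python version for small inputs, with invented example values; it is not necessarily how the project computes the metric.

```python
def concordance_index(pred, ytime, yevent):
    """Fraction of comparable pairs in which the higher-risk sample fails first.

    A pair (i, j) is comparable when sample i had an observed event and
    sample j survived longer than i. Ties in predicted risk count as 0.5.
    """
    concordant = 0.0
    comparable = 0
    n = len(pred)
    for i in range(n):
        if yevent[i] != 1:
            continue
        for j in range(n):
            if ytime[j] > ytime[i]:
                comparable += 1
                if pred[i] > pred[j]:
                    concordant += 1.0
                elif pred[i] == pred[j]:
                    concordant += 0.5
    return concordant / comparable if comparable else float("nan")


ytime = [1.0, 3.0, 5.0, 7.0]
yevent = [1, 1, 0, 1]
c_perfect = concordance_index([2.0, 1.0, 0.5, -1.0], ytime, yevent)
c_reversed = concordance_index([-1.0, 0.5, 1.0, 2.0], ytime, yevent)
c_constant = concordance_index([0.0, 0.0, 0.0, 0.0], ytime, yevent)
```

A value of 1.0 means perfect ranking, 0.5 is what a constant or random predictor achieves, and 0.0 means the ranking is fully reversed.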

vi. Optimal Model Selection:

  • Select the model with the best hyperparameters based on the performance on the validation set.
  • Use the selected model for further evaluation on the test set.

vii. Model Interpretation:

  • Interpret the trained RCoxNet model to gain insights into feature importance and contributions.
  • Save the model weights and node values for further analysis.

viii. Results:

  • Report the key results, including optimal hyperparameters, model performance on the test set, and any additional insights gained from model interpretation.

ix. Save Model:

  • Save the trained RCoxNet model, including its state dictionary, for future use or deployment.

Explanation of RCoxNet parameters:

RWR_data_bulder.py

  1. load_data():

    • Reads clinical patient, clinical sample, and mutation data from three different files (data_clinical_patient.txt, data_clinical_sample.txt, and data_mutations.txt).
    • Returns the dataframes.
  2. clean_data(df_data_mut):

    • Cleans mutation data by extracting the prefix from the 'Tumor_Sample_Barcode' column and saving it to a new CSV file named "clean_tcga.csv".
    • Modifies the 'Tumor_Sample_Barcode' column in the original dataframe.
  3. merge_data(df_clinical_patient, df_clinical_sample, df_data_mut_final):

    • Merges clinical patient, clinical sample, and cleaned mutation data based on patient ID.
    • Returns the merged and filtered dataframe.
  4. get_all_mutated_genes(df_data_mut):

    • Groups mutation data by 'Tumor_Sample_Barcode' and returns a dataframe with unique mutated genes for each sample.
  5. get_deleted_genes(df_data_mut):

    • Reads data from "output_9606.protein.links.full.v11.5.txt" and identifies deleted genes.
    • Returns a list of deleted genes and the original list of genes.
  6. remove_items(test_list, item):

    • Removes specified items from a list.
  7. calculate_seed(df_data_mut_patient_all_mutated, df_new, del_genes_all):

    • Filters out deleted genes from the mutation data and saves the result to "patient_id_updated.csv".
    • Groups the updated mutation data by 'Tumor_Sample_Barcode' and returns a dataframe.
  8. calculate_sg_score(list_of_genes, list_of_mutated_genes):

    • Calculates a score for each gene based on its presence in the mutation data.
    • Saves the results to "score_sg_all_tcga_brca_all.csv".
    • Returns lists of mutated genes and their scores.
    • The returned scores are not currently used in the code.
  9. run_r_script(input_csv_sg, input_csv_seed, output_csv):

    • Executes an R script using rpy2, performing calculations based on input CSV files and saving the results to an output CSV file.
    • Returns the generated CSV file and dataframe.
  10. Extract(lst):

    • Extracts the first element from each list.
  11. process_and_clean_data(df_data_mut_patient_all_mutated, df_data_mut_patient_all_data_new, csv_output_path, data):

    • Processes and cleans additional columns from the mutation data, adds new columns, and saves the cleaned data to a CSV file.
    • Returns the cleaned dataframe.
  12. split_data(data, train_ratio, test_ratio, validation_ratio, random_state):

    • Splits the data into training, testing, and validation sets based on specified ratios.
    • Prints the size of each set and returns the sets.
  13. save_data_to_csv(data, filename):

    • Saves the provided dataframe to a CSV file.
  14. call_funs():

    • Calls all the defined functions in sequence to load, clean, merge, calculate, and process data.
    • Saves the training, testing, and validation sets to CSV files.
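Two of the steps above, the barcode clean-up and the three-way split, are sketched below in plain Python. Both functions are hypothetical stand-ins: the page does not say how long the barcode prefix is (the first three hyphen-separated fields are assumed here), and the real split_data works on dataframes rather than lists.

```python
import random


def patient_id_from_barcode(barcode):
    """Keep the assumed patient-level prefix of a sample barcode."""
    # Assumption: the prefix is the first three hyphen-separated fields.
    return "-".join(barcode.split("-")[:3])


def split_rows(rows, train_ratio=0.7, test_ratio=0.2, validation_ratio=0.1, random_state=42):
    """Shuffle rows reproducibly and split them into train, test, and validation lists."""
    if abs(train_ratio + test_ratio + validation_ratio - 1.0) > 1e-9:
        raise ValueError("train, test, and validation ratios must sum to 1")
    shuffled = list(rows)
    random.Random(random_state).shuffle(shuffled)
    n_train = round(len(shuffled) * train_ratio)
    n_test = round(len(shuffled) * test_ratio)
    train = shuffled[:n_train]
    test = shuffled[n_train:n_train + n_test]
    validation = shuffled[n_train + n_test:]  # whatever remains
    return train, test, validation


prefix = patient_id_from_barcode("TCGA-A1-A0SB-01")
train, test, validation = split_rows(range(100))
```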

dataloader.py

  1. load_sorted_data(file_path, tensor_dtype):

    • Calls the sort_genomic_clinical_data function to obtain sorted data.
    • Converts the sorted data to PyTorch tensors with the specified tensor_dtype.
    • Returns the following PyTorch tensors: X (genomic inputs), YTIME (survival time), YEVENT (censoring status), AGE (age data), MSI (MSI data), TMB (TMB data).
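The page does not state what the data are sorted by. A common choice for Cox-style losses is descending survival time, because the risk set of each sample is then simply the rows above it. The sketch below assumes that ordering and uses plain lists instead of tensors; the record keys (`OS_MONTHS`, `OS_STATUS`, and so on) and the function name are hypothetical.

```python
def load_sorted_columns(records, gene_columns):
    """Sort records by survival time (descending) and split them into named columns."""
    ordered = sorted(records, key=lambda r: r["OS_MONTHS"], reverse=True)
    return {
        "X": [[r[g] for g in gene_columns] for r in ordered],
        "YTIME": [r["OS_MONTHS"] for r in ordered],
        "YEVENT": [r["OS_STATUS"] for r in ordered],
        "AGE": [r["AGE"] for r in ordered],
        "MSI": [r["MSI"] for r in ordered],
        "TMB": [r["TMB"] for r in ordered],
    }


# Made-up records with two gene score columns.
records = [
    {"G1": 0.1, "G2": 0.9, "OS_MONTHS": 12.0, "OS_STATUS": 1, "AGE": 61, "MSI": 0.2, "TMB": 1.5},
    {"G1": 0.4, "G2": 0.6, "OS_MONTHS": 48.0, "OS_STATUS": 0, "AGE": 45, "MSI": 0.1, "TMB": 0.8},
    {"G1": 0.7, "G2": 0.3, "OS_MONTHS": 30.0, "OS_STATUS": 1, "AGE": 70, "MSI": 3.4, "TMB": 9.1},
]
columns = load_sorted_columns(records, ["G1", "G2"])
```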

model.py

  1. RCoxNet Class:
    • Constructor (__init__):

      • Initializes the neural network architecture.
      • Parameters:
        • input_nodes: Number of input nodes (features) for genomic data.
        • hidden_nodes1: Number of nodes in the first hidden layer.
        • hidden_nodes2: Number of nodes in the second hidden layer.
        • output_nodes: Number of nodes in the output layer.
    • Attributes:

      • tanh: The hyperbolic tangent activation function (nn.Tanh()).

      • rwr_layer: The linear layer that takes the Random Walk with Restart (RWR) scores, transforming the genomic input.

      • hidden_layer1: The first hidden linear layer.

      • hidden_layer2: The second hidden linear layer.

      • cox_layer: The linear layer for the Cox Proportional Hazard model, combining the output of the second hidden layer with the additional features (age, MSI, TMB).

    • Methods (forward):

      • Defines the forward pass of the neural network.
      • Takes genomic data (x_genomic), MSI data (x_msi), TMB data (x_tmb), and age data (x_age) as inputs.
      • Applies the tanh activation function to the RWR layer and the two hidden layers.
      • Concatenates the hidden layer 2 output with additional features.
      • Passes the combined features through the Cox Proportional Hazard layer.
      • Returns the output of the Cox Proportional Hazard layer.

This architecture is designed for survival analysis. It takes Random Walk with Restart scores as genomic input, combines the learned genomic representation with the additional clinical features, and produces the output of the Cox Proportional Hazard model in the final layer.
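Putting the description together, a module with these attributes might look like the sketch below. The class name `RCoxNetSketch` and the exact layer sizes are assumptions: in particular, it is a guess that hidden_layer2 maps to output_nodes and that the Cox layer receives output_nodes plus three clinical inputs.

```python
import torch
from torch import nn


class RCoxNetSketch(nn.Module):
    """Illustrative layout only; layer wiring is inferred from the description above."""

    def __init__(self, input_nodes, hidden_nodes1, hidden_nodes2, output_nodes):
        super().__init__()
        self.tanh = nn.Tanh()
        self.rwr_layer = nn.Linear(input_nodes, hidden_nodes1)
        self.hidden_layer1 = nn.Linear(hidden_nodes1, hidden_nodes2)
        self.hidden_layer2 = nn.Linear(hidden_nodes2, output_nodes)
        # Cox layer: one linear node, no bias, fed with genomic features + age, MSI, TMB.
        self.cox_layer = nn.Linear(output_nodes + 3, 1, bias=False)

    def forward(self, x_genomic, x_msi, x_tmb, x_age):
        h = self.tanh(self.rwr_layer(x_genomic))
        h = self.tanh(self.hidden_layer1(h))
        h = self.tanh(self.hidden_layer2(h))
        # Clinical covariates bypass the hidden layers and join just before the Cox layer.
        combined = torch.cat((h, x_age, x_msi, x_tmb), dim=1)
        return self.cox_layer(combined)


torch.manual_seed(0)
net = RCoxNetSketch(input_nodes=20, hidden_nodes1=16, hidden_nodes2=8, output_nodes=4)
batch = 5
prognostic_index = net(
    torch.randn(batch, 20),  # x_genomic: RWR scores
    torch.randn(batch, 1),   # x_msi
    torch.randn(batch, 1),   # x_tmb
    torch.randn(batch, 1),   # x_age
)
```

Each clinical input is expected as a column of shape (batch, 1) so that it can be concatenated with the hidden representation.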

train.py

  1. Trains a Cox proportional hazards model with a neural network structure.

    Parameters:

    • train_x, eval_x: Training and evaluation input features.

    • train_age, eval_age: Age information for training and evaluation.

    • train_ytime, eval_ytime: Time-to-event for training and evaluation.

    • train_yevent, eval_yevent: Event indicator for training and evaluation.

    • train_msi, eval_msi: Microsatellite instability for training and evaluation.

    • train_tmb, eval_tmb: Tumor mutational burden for training and evaluation.

    • In_Nodes, hidden_nodes1, hidden_nodes2, Out_Nodes: Neural network architecture parameters.

    • Learning_Rate: Learning rate for optimization.

    • L2: L2 regularization parameter.

    • Num_Epochs: Number of training epochs.

    Returns:

    • Tuple containing training loss, evaluation loss, training concordance index, and evaluation concordance index.