
4. Detailed Model Architecture

Smruti Panda edited this page Jan 26, 2024 · 4 revisions

The Architecture of RCoxNet:

Cox-RWRNet integrates Cox proportional hazards regression with a deep neural network whose inputs are random walk with restart (RWR) scores. The Cox-RWRNet architecture consists of (1) an RWR layer holding the RWR scores, (2) multiple hidden layers, (3) a clinical layer incorporating age, TMB score, and MSI score, and (4) a Cox layer.

(1). RWR Layer: The RWR layer is the input layer of Cox-RWRNet and holds RWR score data for n patient samples over p genes. Only genes that are mutated in at least one patient are considered. Each patient's mutated genes are used as seeds, a random walk with restart is run on the gene regulatory network from the STRING database, and the result is a relevance score for every gene with respect to the seeds; this is an intermediate output. The seed list is shown below, where P1, P2, and P3 denote patient IDs and the remaining columns list each patient's mutated genes from the TCGA data:

[Figure: mutated (seed) genes per patient]
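The relevance scores come from a random walk with restart. The repository performs this step in R through rpy2 (see run_r_script under RWR_data_bulder.py); the Python sketch here only illustrates the standard RWR iteration p_next = (1 - r) W p + r p0, where W is the column-normalized adjacency matrix of the gene network and p0 spreads the restart probability uniformly over one patient's seed genes. The function name and the restart probability r = 0.7 are assumptions, not values taken from the repository.

```python
import numpy as np


def random_walk_with_restart(adj, seed_idx, restart_prob=0.7, tol=1e-10, max_iter=1000):
    """Return RWR relevance scores of every node with respect to a seed set.

    adj:          (n x n) non-negative adjacency matrix of the gene network
    seed_idx:     indices of the patient's mutated (seed) genes
    restart_prob: probability of jumping back to the seeds at each step (assumed value)
    """
    adj = np.asarray(adj, dtype=float)
    # Column-normalize so each column sums to 1 (transition matrix);
    # isolated genes (all-zero columns) stay as zero columns.
    col_sums = adj.sum(axis=0)
    W = np.divide(adj, col_sums, out=np.zeros_like(adj), where=col_sums > 0)
    # Restart vector: uniform probability over the seed genes
    p0 = np.zeros(adj.shape[0])
    p0[list(seed_idx)] = 1.0 / len(seed_idx)
    p = p0.copy()
    for _ in range(max_iter):
        p_next = (1 - restart_prob) * (W @ p) + restart_prob * p0
        if np.abs(p_next - p).sum() < tol:  # L1 change small enough: converged
            return p_next
        p = p_next
    return p
```

Running this once per patient (with that patient's seeds) and stacking the resulting vectors gives one row of the patient-by-gene score matrix per patient.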

Once the random walk has been run for each patient's seeds, a matrix of RWR scores is obtained, with one score per gene and per patient. The matrix has:

Number of rows: number of unique patients

Number of columns: number of genes in the COSMIC database gene list.


Fig: Unique RWR score matrix of genes

This matrix is later split into training, test, and validation sets.
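A minimal sketch of assembling and splitting the matrix, assuming each patient's RWR scores are already available as a vector ordered like the gene list. The helper names build_rwr_matrix and split_rows and the default 60/20/20 split are hypothetical; the repository's own split_data function takes the ratios as parameters.

```python
import numpy as np
import pandas as pd


def build_rwr_matrix(scores_by_patient, gene_list):
    """Stack per-patient RWR score vectors into a patients x genes DataFrame.

    scores_by_patient: dict mapping patient ID -> 1-D array ordered like gene_list
    (hypothetical input format).
    """
    matrix = pd.DataFrame.from_dict(scores_by_patient, orient="index", columns=gene_list)
    matrix.index.name = "patient_id"
    return matrix


def split_rows(matrix, train_ratio=0.6, test_ratio=0.2, random_state=42):
    """Shuffle patients and split into train / test / validation (remainder) sets."""
    rng = np.random.default_rng(random_state)
    idx = rng.permutation(len(matrix))
    n_train = int(train_ratio * len(matrix))
    n_test = int(test_ratio * len(matrix))
    train = matrix.iloc[idx[:n_train]]
    test = matrix.iloc[idx[n_train:n_train + n_test]]
    valid = matrix.iloc[idx[n_train + n_test:]]
    return train, test, valid
```

Splitting by patient (rows) keeps every patient's scores in exactly one of the three sets.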

(2). Hidden Layers: Hidden layers enable neural networks to learn and represent complex relationships in data by capturing hierarchical features, introducing non-linearities, and automatically extracting relevant features.

(3). Clinical Layer: The clinical layer in Cox-RWRNet incorporates key clinical covariates: age, TMB (tumor mutational burden) score, and MSI (microsatellite instability) score. These attributes are known to have a substantial impact on prognosis. By feeding these covariates directly into the model's output layer alongside the highest-level representation of the genomic data, Cox-RWRNet considers the effects of genomic and clinical data separately. This design is intended to improve prognostic accuracy by keeping clinically relevant features from being overlooked among the much larger number of genomic inputs, so that Cox-RWRNet is trained effectively on survival data and provides accurate predictions and meaningful insights.

To make the matrices above usable by the proposed architecture, the following information is appended to them:

  1. Overall Survival Months
  2. Overall Survival Status
  3. Age
  4. TMB (Tumor Mutational Burden)
  5. MSI Sensor Score
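A sketch of the appending step, joining the clinical table onto the RWR matrix by patient ID. The column names (OS_MONTHS, OS_STATUS, AGE, TMB, MSI_SCORE) and the cBioPortal-style status strings such as "1:DECEASED" are assumptions about the clinical files, not names confirmed by this page; append_clinical is a hypothetical helper.

```python
import pandas as pd

# Hypothetical column names; adapt to the clinical export actually used.
CLINICAL_COLS = ["OS_MONTHS", "OS_STATUS", "AGE", "TMB", "MSI_SCORE"]


def append_clinical(rwr_matrix, clinical):
    """Join survival outcome and clinical covariates onto a patients x genes RWR matrix.

    rwr_matrix: DataFrame indexed by patient ID
    clinical:   DataFrame with a 'patient_id' column plus CLINICAL_COLS
    """
    covariates = clinical.set_index("patient_id")[CLINICAL_COLS]
    # Inner join keeps only patients present in both tables, in rwr_matrix order
    merged = rwr_matrix.join(covariates, how="inner")
    # Assumes status strings like "1:DECEASED" / "0:LIVING"; keep the numeric event flag
    merged["OS_EVENT"] = merged["OS_STATUS"].str.split(":").str[0].astype(int)
    return merged
```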

This information is needed by the Cox proportional hazards (Cox-PH) model to compute the prognostic index. The Cox-PH model is a statistical method for examining the association between an individual's survival time and one or more predictor variables. It is especially useful because it makes no assumption about the distribution of survival times; it assumes only proportional hazards, i.e., that the hazard ratio between any two individuals stays constant over time. The hazard function describes the instantaneous failure rate at a given time. The output of this module is the prognostic index, which serves as the input to the final module, survival analysis.
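In standard notation (a textbook sketch, not copied from the repository), the Cox model writes the hazard of patient i with covariates x_i as a baseline hazard h_0(t) scaled by the exponential of a linear predictor, and that linear predictor is the prognostic index. As far as this page describes it, the Cox layer output of Cox-RWRNet plays the role of PI_i. Because h_0(t) cancels in the ratio, the hazard ratio between two patients does not depend on t, which is exactly the proportional hazards assumption:

```latex
h(t \mid \mathbf{x}_i) = h_0(t)\,\exp\!\big(\boldsymbol{\beta}^{\top}\mathbf{x}_i\big),
\qquad \mathrm{PI}_i = \boldsymbol{\beta}^{\top}\mathbf{x}_i,
\qquad \frac{h(t \mid \mathbf{x}_i)}{h(t \mid \mathbf{x}_j)} = \exp\!\big(\mathrm{PI}_i - \mathrm{PI}_j\big)
```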

Prognostic index (PI): The PI summarizes each patient's predicted risk from the model inputs; a higher PI means a higher predicted hazard and therefore shorter expected survival. Using their PI values, we divided patients into two groups: a high-risk group (more likely to experience the event, such as death) and a low-risk group. To check whether this division is meaningful, we used the log-rank test, which compares the survival of the high-risk and low-risk groups. A small p-value (typically below 0.05) indicates a significant difference in survival between the two groups, i.e., that the prognostic index separates patients into groups with different survival experiences. Identifying higher- and lower-risk patients in this way can help inform decisions about treatment and care.
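The median split and the log-rank test can be sketched as follows. This is a plain NumPy implementation of the standard two-group log-rank statistic (chi-square with 1 degree of freedom), written for illustration; the function names median_split and logrank_test are hypothetical, and survival libraries such as lifelines provide equivalent, tested implementations.

```python
import math

import numpy as np


def median_split(pi):
    """True for patients whose PI is above the median (high-risk group)."""
    pi = np.asarray(pi, dtype=float)
    return pi > np.median(pi)


def logrank_test(time, event, group):
    """Two-group log-rank test; returns (chi-square statistic, p-value with 1 df).

    time:  survival times
    event: 1 if the event (e.g. death) was observed, 0 if censored
    group: boolean array, True for the high-risk group
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=int)
    group = np.asarray(group, dtype=bool)
    obs_minus_exp, variance = 0.0, 0.0
    for t in np.unique(time[event == 1]):  # each distinct event time
        at_risk = time >= t
        n = at_risk.sum()                 # patients at risk overall
        n1 = (at_risk & group).sum()      # at risk in the high-risk group
        d = ((time == t) & (event == 1)).sum()
        d1 = ((time == t) & (event == 1) & group).sum()
        obs_minus_exp += d1 - d * n1 / n  # observed minus expected events in group 1
        if n > 1:
            # Hypergeometric variance of d1
            variance += d * (n1 / n) * (1 - n1 / n) * (n - d) / (n - 1)
    chi2 = obs_minus_exp ** 2 / variance
    p_value = math.erfc(math.sqrt(chi2 / 2))  # upper tail of chi-square with 1 df
    return chi2, p_value
```

Usage: `chi2, p = logrank_test(time, event, median_split(pi))`; a p-value below 0.05 would indicate that the PI-based groups have significantly different survival.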

(4). Cox Layer:

Objective Function: To perform Cox proportional hazards regression in the Cox layer, Cox-RWRNet formulates the objective function as the average negative log partial likelihood plus an L2 regularization term.

Negative log partial likelihood: The negative log partial likelihood is the standard loss for survival analysis with censored data. It is derived from the Cox partial likelihood, which depends only on the ordering of event times and not on the baseline hazard, and it is minimized to find the parameter values that best fit the observed survival data.

When incorporating L2 regularization (also known as ridge regularization or Tikhonov regularization), the objective function to be minimized becomes a combination of the negative log partial likelihood and a regularization term. The L2 regularization term penalizes large values of the coefficients to prevent overfitting.

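Assuming the objective follows the formulation described in the text (average negative log partial likelihood plus an L2 penalty, as in Cox-PASNet), it can be written as below, where PI_i is the network output (prognostic index) for patient i, δ_i the event indicator, t_i the survival time, n_E the number of observed events, Θ the network weights, and λ the L2 regularization strength. Tied event times are handled here with the Breslow convention (risk set t_j ≥ t_i), which is an assumption.

```latex
\ell(\Theta) = -\frac{1}{n_E} \sum_{i:\,\delta_i = 1}
  \Big( \mathrm{PI}_i - \log \sum_{j:\, t_j \ge t_i} \exp(\mathrm{PI}_j) \Big)
  + \lambda \,\lVert \Theta \rVert_2^2
```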

Training RWR-PASNet:

i. Initialization:

  • Initialize the RWR-PASNet model with specified architecture parameters such as the number of input nodes (In_Nodes), hidden nodes (Hidden_Nodes), and output nodes (Out_Nodes).
  • Define the hyperparameters, including the initial learning rate (Initial_Learning_Rate), L2 regularization lambda (L2_Lambda), and the number of training epochs (Num_EPOCHS).
  • Set up the Adam optimizer to update the model parameters during training.

ii. Load Data:

  • Load the training data (x_train, age_train, msi_train, tmb_train, ytime_train, yevent_train) and validation data (x_valid, age_valid, msi_valid, tmb_valid, ytime_valid, yevent_valid) using the load_data function.
  • The data includes genomic inputs (x), age, MSI (Microsatellite Instability) score, TMB (Tumor Mutational Burden) score, survival time (ytime), and censoring status (yevent).

iii. Grid Search for Hyperparameters:

  • Perform a grid search over hyperparameters (L2 regularization and learning rate) using the training and validation data.
  • Train the RWR-PASNet for each combination of hyperparameters and evaluate the model's performance on the validation set.
  • Choose the hyperparameters that result in the best validation loss.
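The grid search can be sketched as a loop over every (L2, learning rate) pair that keeps the combination with the lowest validation loss. Here train_fn is a hypothetical callable that trains a fresh model with the given hyperparameters and returns its validation loss; the repository may structure this step differently.

```python
from itertools import product


def grid_search(train_fn, l2_values, lr_values):
    """Return (best_l2, best_lr, best_val_loss) over all hyperparameter combinations."""
    best = None
    for l2, lr in product(l2_values, lr_values):
        val_loss = train_fn(learning_rate=lr, l2=l2)  # train once per combination
        if best is None or val_loss < best[2]:
            best = (l2, lr, val_loss)
    return best
```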

iv. Training Loop:

  • Execute the training loop for the specified number of epochs (Num_EPOCHS).
  • For each epoch:
    • Set the model to training mode and zero out the gradients (opt.zero_grad()).
    • Forward pass: Pass the training data through the model (net) to obtain predictions.
    • Compute the negative partial log-likelihood loss using the neg_par_log_likelihood function.
    • Backward pass: Compute gradients and update the model parameters using the Adam optimizer (opt.step()).
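A minimal PyTorch sketch of the loss and the full-batch training loop described in steps i and iv. This neg_par_log_likelihood is an illustrative reimplementation (Breslow handling of tied times, risk sets built from a pairwise time comparison) and may differ in detail from the repository's function of the same name; train_loop and its defaults are hypothetical. The L2 penalty is added explicitly to the loss, matching the objective described under the Cox layer, rather than through Adam's weight_decay argument.

```python
import torch


def neg_par_log_likelihood(pred, ytime, yevent):
    """Average negative log partial likelihood (Breslow handling of tied times).

    pred:   prognostic index from the network, shape (n, 1) or (n,)
    ytime:  survival times; yevent: 1 if the event was observed, 0 if censored
    """
    pred, ytime, yevent = pred.view(-1), ytime.view(-1), yevent.view(-1).float()
    n = pred.shape[0]
    # at_risk[i, j] is True when patient j is still at risk at patient i's time (t_j >= t_i)
    at_risk = ytime.view(1, -1) >= ytime.view(-1, 1)
    # log-sum-exp of pred over each risk set, computed stably
    masked = pred.view(1, -1).repeat(n, 1).masked_fill(~at_risk, float("-inf"))
    log_risk = torch.logsumexp(masked, dim=1)
    # Only observed events contribute; average over the number of events
    return -((pred - log_risk) * yevent).sum() / yevent.sum()


def train_loop(net, inputs, ytime, yevent, lr=1e-3, l2=1e-3, num_epochs=100):
    """Full-batch training with Adam and an explicit L2 penalty; returns the loss history."""
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    history = []
    for _ in range(num_epochs):
        net.train()
        opt.zero_grad()
        pred = net(*inputs)  # forward pass
        loss = neg_par_log_likelihood(pred, ytime, yevent)
        loss = loss + l2 * sum((p ** 2).sum() for p in net.parameters())
        loss.backward()      # backward pass
        opt.step()           # Adam update
        history.append(loss.item())
    return history
```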

v. Model Evaluation:

  • After training, evaluate the model on the test data (x_test, age_test, msi_test, tmb_test, ytime_test, yevent_test).
  • Calculate the negative partial log-likelihood and concordance index for survival analysis using the trained model.
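The concordance index used for evaluation can be illustrated with Harrell's C-index: among comparable pairs (the patient with the shorter time had an observed event), it is the fraction in which the model assigns the higher risk to the patient who failed first, with risk ties counted as one half. This O(n^2) sketch is for clarity only; the function name is hypothetical, and the repository may use a library implementation instead.

```python
import numpy as np


def concordance_index(ytime, yevent, risk):
    """Harrell's C-index; 1.0 is perfect ranking, 0.5 is random."""
    ytime, yevent, risk = (np.asarray(a, dtype=float) for a in (ytime, yevent, risk))
    concordant, comparable = 0.0, 0
    for i in range(len(ytime)):
        if yevent[i] != 1:       # censored patients cannot anchor a comparable pair
            continue
        for j in range(len(ytime)):
            if ytime[i] < ytime[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable
```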

vi. Optimal Model Selection:

  • Select the model with the best hyperparameters based on the performance on the validation set.
  • Use the selected model for further evaluation on the test set.

vii. Model Interpretation:

  • Interpret the trained RWR-PASNet model to gain insights into feature importance and contributions.
  • Save the model weights and node values for further analysis.

viii. Results:

  • Report the key results, including optimal hyperparameters, model performance on the test set, and any additional insights gained from model interpretation.

ix. Save Model:

  • Save the trained RWR-PASNet model, including its state dictionary, for future use or deployment.

Step 4: Survival Analysis

The Kaplan-Meier estimator is a non-parametric technique for estimating the survival function. The prognostic index obtained above is passed to the Kaplan-Meier analysis together with the time-to-event data and event status to estimate the overall survival of the patients. Using the median PI as the threshold, patients are split into two groups (below and above the threshold), a survival curve is estimated for each group, and a p-value is computed to test whether the survival of the two groups differs.

The data are thus classified into two groups, high-risk and low-risk, based on the prognostic index computed by Cox-RWRNet: patients whose PI is above the median are placed in the high-risk group, and those whose PI is below the median are placed in the low-risk group.
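A minimal Kaplan-Meier sketch that could be applied separately to the high-risk and low-risk groups. At each distinct event time t it multiplies the running survival estimate by (1 - d/n), where d is the number of events at t and n the number of patients still at risk. The function name is hypothetical; libraries such as lifelines also provide a KaplanMeierFitter class for this.

```python
import numpy as np


def kaplan_meier(time, event):
    """Return distinct event times and the Kaplan-Meier survival estimate at each."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=int)
    event_times = np.unique(time[event == 1])
    survival, s = [], 1.0
    for t in event_times:
        n_at_risk = (time >= t).sum()           # still under observation at t
        d = ((time == t) & (event == 1)).sum()  # events occurring at t
        s *= 1.0 - d / n_at_risk
        survival.append(s)
    return event_times, np.array(survival)
```

For example, with times [1, 2, 3, 4] and events [1, 0, 1, 1] (the patient at time 2 is censored), the estimate drops to 0.75 at t = 1, 0.375 at t = 3, and 0 at t = 4.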

We evaluated Cox-RWRNet on TCGA data for breast (BRCA), glioblastoma (GBM), lung, and ovarian cancer. In our experiments, Cox-RWRNet outperformed other survival methods such as DeepSurv, DeepHit, and Cox-PASNet, and we performed statistical assessments to validate its predictive capability.

Explanation of Cox-RWRNet code modules and parameters:

RWR_data_bulder.py

  1. load_data():

    • Reads clinical patient, clinical sample, and mutation data from three different files (data_clinical_patient.txt, data_clinical_sample.txt, and data_mutations.txt).
    • Returns the dataframes.
  2. clean_data(df_data_mut):

    • Cleans mutation data by extracting the prefix from the 'Tumor_Sample_Barcode' column and saving it to a new CSV file named "clean_tcga.csv".
    • Modifies the 'Tumor_Sample_Barcode' column in the original dataframe.
  3. merge_data(df_clinical_patient, df_clinical_sample, df_data_mut_final):

    • Merges clinical patient, clinical sample, and cleaned mutation data based on patient ID.
    • Returns the merged and filtered dataframe.
  4. get_all_mutated_genes(df_data_mut):

    • Groups mutation data by 'Tumor_Sample_Barcode' and returns a dataframe with unique mutated genes for each sample.
  5. get_deleted_genes(df_data_mut):

    • Reads data from "output_9606.protein.links.full.v11.5.txt" and identifies deleted genes.
    • Returns a list of deleted genes and the original list of genes.
  6. remove_items(test_list, item):

    • Removes specified items from a list.
  7. calculate_seed(df_data_mut_patient_all_mutated, df_new, del_genes_all):

    • Filters out deleted genes from the mutation data and saves the result to "patient_id_updated.csv".
    • Groups the updated mutation data by 'Tumor_Sample_Barcode' and returns a dataframe.
  8. calculate_sg_score(list_of_genes, list_of_mutated_genes):

    • Calculates a score for each gene based on its presence in the mutation data.
    • Saves the results to "score_sg_all_tcga_brca_all.csv".
    • Returns lists of mutated genes and their scores.
    • Note: the scores are not used elsewhere in the code.
  9. run_r_script(input_csv_sg, input_csv_seed, output_csv):

    • Executes an R script through rpy2 that performs calculations on the input CSV files and saves the results to an output CSV file.
    • Returns the generated CSV file and a dataframe.
  10. Extract(lst):

    • Extracts the first element of each sublist in lst.
  11. process_and_clean_data(df_data_mut_patient_all_mutated, df_data_mut_patient_all_data_new, csv_output_path, data):

    • Processes and cleans additional columns from the mutation data, adds new columns, and saves the cleaned data to a CSV file.
    • Returns the cleaned dataframe.
  12. split_data(data, train_ratio, test_ratio, validation_ratio, random_state):

    • Splits the data into training, testing, and validation sets based on the specified ratios.
    • Prints the size of each set and returns the sets.
  13. save_data_to_csv(data, filename):

    • Saves the provided dataframe to a CSV file.
  14. call_funs():

    • Calls all the defined functions in sequence to load, clean, merge, calculate, and process data.
    • Saves the training, testing, and validation sets to CSV files.
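Two of the data-builder steps, sketched with pandas under stated assumptions: trimming TCGA sample barcodes to patient IDs (assumed here to be the first 12 characters, e.g. TCGA-A1-A0SB from TCGA-A1-A0SB-01) and grouping mutated genes per patient to obtain each patient's seed genes. The 'Hugo_Symbol' gene column is the standard MAF column name but is not named on this page, and both helper names are hypothetical.

```python
import pandas as pd


def clean_barcodes(df_data_mut):
    """Reduce 'Tumor_Sample_Barcode' values to 12-character TCGA patient IDs (assumed prefix)."""
    df = df_data_mut.copy()
    df["Tumor_Sample_Barcode"] = df["Tumor_Sample_Barcode"].str[:12]
    return df


def mutated_genes_per_patient(df_data_mut):
    """Map each patient to the sorted list of unique mutated genes (seed genes)."""
    return (
        df_data_mut.groupby("Tumor_Sample_Barcode")["Hugo_Symbol"]
        .apply(lambda genes: sorted(set(genes)))
    )
```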

dataloader.py

  • load_sorted_data(file_path, tensor_dtype):
    • Calls the sort_genomic_clinical_data function to obtain sorted data.
    • Converts the sorted data to PyTorch tensors with the specified tensor_dtype.
    • Returns the following PyTorch tensors:
      • X: genomic inputs
      • YTIME: survival time
      • YEVENT: censoring status
      • AGE: age data
      • MSI: MSI data
      • TMB: TMB data
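A sketch of what load_sorted_data might do, assuming the cleaned CSV has one row per patient, gene columns, and clinical columns named OS_MONTHS, OS_EVENT, AGE, MSI, and TMB (hypothetical names). Sorting by survival time with the longest first is also an assumption about sort_genomic_clinical_data. The function names below are hypothetical, and the tensors are returned in the order listed above.

```python
import pandas as pd
import torch

CLINICAL_COLS = ["OS_MONTHS", "OS_EVENT", "AGE", "MSI", "TMB"]  # hypothetical names


def frame_to_sorted_tensors(df, tensor_dtype=torch.float32):
    """Sort patients by survival time (descending) and return X, YTIME, YEVENT, AGE, MSI, TMB."""
    df = df.sort_values("OS_MONTHS", ascending=False)

    def col(name):
        # Each clinical variable becomes an (n, 1) column tensor
        return torch.tensor(df[[name]].to_numpy(dtype=float), dtype=tensor_dtype)

    x = torch.tensor(df.drop(columns=CLINICAL_COLS).to_numpy(dtype=float), dtype=tensor_dtype)
    return x, col("OS_MONTHS"), col("OS_EVENT"), col("AGE"), col("MSI"), col("TMB")


def load_sorted_data_sketch(file_path, tensor_dtype=torch.float32):
    """Read the CSV (patient ID in the first column) and convert it to tensors."""
    return frame_to_sorted_tensors(pd.read_csv(file_path, index_col=0), tensor_dtype)
```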

model.py

  1. CoxPhRWRNet Class:
    • Constructor (__init__):

      • Initializes the neural network architecture.
      • Parameters:
        • input_nodes: Number of input nodes (features) for genomic data.
        • hidden_nodes1: Number of nodes in the first hidden layer.
        • hidden_nodes2: Number of nodes in the second hidden layer.
        • output_nodes: Number of nodes in the output layer.
    • Attributes:

      • tanh: The hyperbolic tangent activation function (nn.Tanh()).

      • rwr_layer: The linear layer for the Random Walk with Restart (RWR) method, transforming genomic input.

      • hidden_layer1: The first hidden linear layer.

      • hidden_layer2: The second hidden linear layer.

      • cox_layer: The linear layer for the Cox proportional hazards model, combining the output of the second hidden layer with additional features (age, MSI, TMB).

    • Methods (forward):

      • Defines the forward pass of the neural network.
      • Takes genomic data (x_genomic), MSI data (x_msi), TMB data (x_tmb), and age data (x_age) as inputs.
      • Applies the tanh activation function to the RWR layer and the two hidden layers.
      • Concatenates the hidden layer 2 output with additional features.
      • Passes the combined features through the Cox proportional hazards layer.
      • Returns the output of the Cox proportional hazards layer.

This architecture is designed for survival analysis: it combines genomic data with additional clinical features, uses random walk with restart scores as the input to the initial layers, and produces the output of the Cox proportional hazards model in the final layer.
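A PyTorch sketch of the architecture described above. The attribute names follow this page, but the layer sizes are assumptions: rwr_layer maps input_nodes to hidden_nodes1, hidden_layer1 maps hidden_nodes1 to hidden_nodes2, hidden_layer2 keeps hidden_nodes2, and cox_layer takes the hidden_layer2 output plus the three clinical features. The class name CoxPhRWRNetSketch is hypothetical. The cox_layer is given no bias because a constant shift in the linear predictor is absorbed by the baseline hazard.

```python
import torch
from torch import nn


class CoxPhRWRNetSketch(nn.Module):
    def __init__(self, input_nodes, hidden_nodes1, hidden_nodes2, output_nodes=1):
        super().__init__()
        self.tanh = nn.Tanh()
        self.rwr_layer = nn.Linear(input_nodes, hidden_nodes1)      # RWR scores in
        self.hidden_layer1 = nn.Linear(hidden_nodes1, hidden_nodes2)
        self.hidden_layer2 = nn.Linear(hidden_nodes2, hidden_nodes2)
        # +3 inputs for age, MSI and TMB, concatenated after the last hidden layer
        self.cox_layer = nn.Linear(hidden_nodes2 + 3, output_nodes, bias=False)

    def forward(self, x_genomic, x_msi, x_tmb, x_age):
        h = self.tanh(self.rwr_layer(x_genomic))
        h = self.tanh(self.hidden_layer1(h))
        h = self.tanh(self.hidden_layer2(h))
        clinical = torch.cat([x_age, x_msi, x_tmb], dim=1)  # each input is (n, 1)
        return self.cox_layer(torch.cat([h, clinical], dim=1))  # prognostic index
```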

train.py

  1. Trains a Cox proportional hazards model with a neural network structure.

    Parameters:

    • train_x, eval_x: Training and evaluation input features.

    • train_age, eval_age: Age information for training and evaluation.

    • train_ytime, eval_ytime: Time-to-event for training and evaluation.

    • train_yevent, eval_yevent: Event indicator for training and evaluation.

    • train_msi, eval_msi: Microsatellite instability for training and evaluation.

    • train_tmb, eval_tmb: Tumor mutational burden for training and evaluation.

    • In_Nodes, hidden_nodes1, hidden_nodes2, Out_Nodes: Neural network architecture parameters.

    • Learning_Rate: Learning rate for optimization.

    • L2: L2 regularization parameter.

    • Num_Epochs: Number of training epochs.

    Returns:

    • Tuple containing training loss, evaluation loss, training concordance index, and evaluation concordance index.

Run_train.py