diff --git a/INFO.txt b/INFO.txt deleted file mode 100644 index 587ee08..0000000 --- a/INFO.txt +++ /dev/null @@ -1,7 +0,0 @@ -All the files are placeholder functions to illustrate the functionality implemented in the directory. The full directory will be uploaded soon. - -* main_train: the main training process, the trained neural network model will be saved in model_result. -* eval_simulation: generate DeepSIF imaging results on synthetic data, evaluate the model performance. -* eval_real: generate DeepSIF imaging results on real data. -* forward/run_tvb: generate synthetic training and texting data. -* anatomy: anatomical information used. See anatomy/README.md for more information. diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..1930c52 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,32 @@ +The Clear BSD License + +Copyright (c) 2022, authors of DeepSIF +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted (subject to the limitations in the disclaimer +below) provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER +IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/README.md b/README.md index c6805d3..1985ebd 100644 --- a/README.md +++ b/README.md @@ -1 +1,95 @@ -Placeholder directory for: Deep Neural Networks Constrained by Neural Mass Models Improve Electrophysiological Source Imaging of Spatio-temporal Brain Dynamics +# DeepSIF: Deep Learning based Source Imaging Framework + + +DeepSIF is an EEG/MEG source imaging framework aiming at providing an estimation of the location, size, and temporal activity of the brain activities from scalp EEG/MEG recordings. There are three components: training data generation (```forward/```), neural network training (```main.py```), and model evaluation (```eval_sim.py```,```eval_recal.py```), as detailed below. The codes are provided as a service to the scientific community, and should be used at users’ own risks. + + +This work was supported in part by the National Institutes of Health grants NS096761, EB021027, AT009263, MH114233, EB029354, and NS124564, awarded to Dr. Bin He, Carnegie Mellon University. Additional data in 20 human epilepsy patients tested in this work can be found at + + +https://figshare.com/s/580622eaf17108da49d7. + + +Please cite the following publication if you are using any part of the codes: + +Sun R, Sohrabpour A, Worrell GA, He B: “Deep Neural Networks Constrained by Neural Mass Models Improve Electrophysiological Source Imaging of Spatio-temporal Brain Dynamics.” Proceedings of the National Academy of Sciences, 2022. + + + +## Train Data Generation +#### The Virtual Brain Simulation +```bash +python generate_tvb_data.py --a_start 0 --a_end 10 +``` +The simulation for each region can also run in parallel. (Require multiprocessing installed.) + +#### Process Raw TVB Data and Prepare Training/Testing Dataset +Run in Matlab +```matlab +process_raw_nmm +generate_sythetic_source +``` +The output of ```generate_sythetic_source``` can be used as input for ```loaders.SpikeEEGBuild``` or ```loaders.SpikeEEGBuildEval``` + +## Training + +After sythetic training dataset is generated, ```main.py``` can be used to train a DeepSIF model. ```network.py``` contains the architecture used + in the paper. ```loaders.py``` provides two ways to load the dataset. If the data is already saved in seperate input/output files + , ```SpikeEEGLoad``` can be used. If training data is generated on the run, ```SpikeEEGBuild``` can be used to generate different types of + training data. To train a model, use + +```bash +python main.py --model_id 1 +``` +**Parameters:** + + '--save', type=int, default=True, help='save each epoch or not' + '--workers', default=0, type=int, help='number of data loading workers' + '--batch_size', default=64, type=int, help='batch size' + '--device', default='cuda:0', type=str, help='device running the code' + '--arch', default='TemporalInverseNet', type=str, help='network achitecture class' + '--dat', default='SpikeEEGBuild', type=str, help='data loader class' + '--train', default='test_sample_source2.mat', type=str, help='train dataset name or directory' + '--test', default='test_sample_source2.mat', type=str, help='test dataset name or directory' + '--model_id', default=75, type=int, help='model id' + '--lr', default=3e-4, type=float, help='learning rate' + '--resume', default='1', type=str, help='epoch id to resume' + '--epoch', default=20, type=int, help='total number of epoch' + '--fwd', default='leadfield_75_20k.mat', type=str, help='forward matrix to use' + '--rnn_layer', default=3, type=int, help='number of rnn layer' + '--info', default='', type=str, help='other information regarding this model' + +## Evaluation + +#### Simulation : +After a model is trained, ```eval_sim.py``` can be used to evaluate the trained model in simulations under different conditions. Some examples are: +```bash +python eval_sim.py --model_id 75 +``` +Additional tests: narrow-band input +```bash +python eval_sim.py --model_id 75 --lfreq 1 --hfreq 3 +``` +Additional tests: different noise type +```bash +python eval_sim.py --model_id 75 --snr_rsn_ratio 0.5 +``` +Additional tests: different head / conductivity value / electrode locations +```bash +python eval_sim.py --model_id 75 --fwd +``` +#### Real data : +Or use real data as the model input as shown in ```eval_real.py```: +```bash +python eval_real.py +``` +Default subject folder : VEP + + +## Dependencies +* [Python >= 3.8.3](https://www.python.org/downloads/) +* [PyTorch>=1.6.0](https://pytorch.org/) +* [tvb](https://www.thevirtualbrain.org/tvb/zwei) +* [mne](https://mne.tools/stable/index.html) +* [h5py](https://www.h5py.org/) +* [numpy](https://numpy.org/) \ No newline at end of file diff --git a/anatomy/README.md b/anatomy/README.md new file mode 100644 index 0000000..4d686b2 --- /dev/null +++ b/anatomy/README.md @@ -0,0 +1,38 @@ +### This folder contains miscellaneous files related to the forward process +#### For TVB simulation +* connectivity_76.zip +* connectivity_998.zip + +Connectivity file used for the TVB simulation, downloaded from +https://github.com/the-virtual-brain/tvb-data + +#### Template headmodel: +- **fs_cortex_20k.mat**: fsaverage5 cortex + - pos, tri, ori: vertices positions, triangulations, and dipole orientations; resolution 20484 + - left_ind, left_tri: index for left hemisphere, triangulations in the left hemisphere + - right_ind, right_tri: index for right hemisphere, triangulations in the right hemisphere + +- **fs_cortex_20k_inflated.mat**: inflated fsaverage5 cortex + - pos, tri: vertices positions, triangulations; resolution 20484 + - posl, tril: vertices positions and triangulations in the left hemisphere + - posr, trir: vertices positions and triangulations in the right hemisphere + +- **fs_cortex_20k_region_mapping.mat** : map fsaverage5 to 994 regions + - rm: region mapping id, size 1*20484 + - nbs: neighbours for each region + +- **leadfield_75_20k.mat** : leadfield matrix for fsaverage5, 75 channels + - fwd: size 75*994 + +- **dis_matrix_fs_20k.mat** : distance between source centres + - raw_dis_matrix: size 994*994 + +- **electrode_75.mat** : 75 EEG channels in EEGLAB format + - eloc75 + +- **fsaverage5/**: contains the files for the raw freesurfer output, for plotting in mne + +#### Simulations: +- **realistic_noise.mat** : resting data extracted from 75 channel EEG recordings + - data: num_examples * num_time* num_channel; 4*500*75 + - npower: the power for each channel; 4*75 \ No newline at end of file diff --git a/anatomy/connectivity_76.zip b/anatomy/connectivity_76.zip new file mode 100644 index 0000000..70bc71e Binary files /dev/null and b/anatomy/connectivity_76.zip differ diff --git a/anatomy/connectivity_998.zip b/anatomy/connectivity_998.zip new file mode 100644 index 0000000..8121fac Binary files /dev/null and b/anatomy/connectivity_998.zip differ diff --git a/anatomy/dis_matrix_fs_20k.mat b/anatomy/dis_matrix_fs_20k.mat new file mode 100644 index 0000000..07f1dde Binary files /dev/null and b/anatomy/dis_matrix_fs_20k.mat differ diff --git a/anatomy/electrode_75.mat b/anatomy/electrode_75.mat new file mode 100644 index 0000000..02b681b Binary files /dev/null and b/anatomy/electrode_75.mat differ diff --git a/anatomy/fs_cortex_20k.mat b/anatomy/fs_cortex_20k.mat new file mode 100644 index 0000000..a382ad6 Binary files /dev/null and b/anatomy/fs_cortex_20k.mat differ diff --git a/anatomy/fs_cortex_20k_inflated.mat b/anatomy/fs_cortex_20k_inflated.mat new file mode 100644 index 0000000..79f105b Binary files /dev/null and b/anatomy/fs_cortex_20k_inflated.mat differ diff --git a/anatomy/fs_cortex_20k_region_mapping.mat b/anatomy/fs_cortex_20k_region_mapping.mat new file mode 100644 index 0000000..c9d3d0e Binary files /dev/null and b/anatomy/fs_cortex_20k_region_mapping.mat differ diff --git a/anatomy/fsaverage5/surf/lh.curv b/anatomy/fsaverage5/surf/lh.curv new file mode 100644 index 0000000..5bf663c Binary files /dev/null and b/anatomy/fsaverage5/surf/lh.curv differ diff --git a/anatomy/fsaverage5/surf/lh.pial b/anatomy/fsaverage5/surf/lh.pial new file mode 100644 index 0000000..ac8cd91 Binary files /dev/null and b/anatomy/fsaverage5/surf/lh.pial differ diff --git a/anatomy/fsaverage5/surf/rh.curv b/anatomy/fsaverage5/surf/rh.curv new file mode 100644 index 0000000..60e603e Binary files /dev/null and b/anatomy/fsaverage5/surf/rh.curv differ diff --git a/anatomy/fsaverage5/surf/rh.pial b/anatomy/fsaverage5/surf/rh.pial new file mode 100644 index 0000000..7456d05 Binary files /dev/null and b/anatomy/fsaverage5/surf/rh.pial differ diff --git a/anatomy/leadfield_75_20k.mat b/anatomy/leadfield_75_20k.mat new file mode 100644 index 0000000..d1a261e Binary files /dev/null and b/anatomy/leadfield_75_20k.mat differ diff --git a/anatomy/realistic_noise.mat b/anatomy/realistic_noise.mat new file mode 100644 index 0000000..a227785 Binary files /dev/null and b/anatomy/realistic_noise.mat differ diff --git a/eval_real.py b/eval_real.py new file mode 100644 index 0000000..77d99fb --- /dev/null +++ b/eval_real.py @@ -0,0 +1,80 @@ +import argparse +import os +import time +from scipy.io import loadmat, savemat +import numpy as np +import glob + +import torch +import network + + +def main(): + start_time = time.time() + # parse the input + parser = argparse.ArgumentParser(description='DeepSIF Model') + parser.add_argument('--device', default='cpu', type=str, help='device running the code') + parser.add_argument('--model_id', type=int, default=64, help='model id') + parser.add_argument('--resume', default='', type=str, help='epoch id to resume') + parser.add_argument('--info', default='', type=str, help='other information regarding this model') + args = parser.parse_args() + + # ======================= PREPARE PARAMETERS ===================================================================================================== + use_cuda = (False) and torch.cuda.is_available() # Only use GPU during training + device = torch.device(args.device if use_cuda else "cpu") + result_root = 'model_result/{}_the_model'.format(args.model_id) + if not os.path.exists(result_root): + print("ERROR: No model {}".format(args.model_id)) + return + + # =============================== LOAD MODEL ===================================================================================================== + if args.resume: + fn = fn = os.path.join(result_root, 'epoch_' + args.resume) + else: + fn = os.path.join(result_root, 'model_best.pth.tar') + print("=> Load checkpoint", fn) + if os.path.isfile(fn): + print("=> Found checkpoint '{}'".format(fn)) + checkpoint = torch.load(fn, map_location=torch.device('cpu')) + best_result = checkpoint['best_result'] + net = network.__dict__[checkpoint['arch']](*checkpoint['attribute_list']).to(device) # redefine the weights architecture + net.load_state_dict(checkpoint['state_dict'], strict=False) + print("=> Loaded checkpoint {}, current results: {}".format(fn, best_result)) + else: + print("ERROR: no checkpoint found") + return + + print('Number of parameters:', net.count_parameters()) + print('Prepare time:', time.time() - start_time) + + # =============================== EVALUATION ===================================================================================================== + net.eval() + subject_list = ['VEP'] + for pii in subject_list: + folder_name = 'source/{}'.format(pii) + start_time = time.time() + flist = glob.glob(folder_name + '/data*.mat') + if len(flist) == 0: + print('WARNING: NO FILE IN FOLDER {}.'.format(folder_name)) + continue + flist = sorted(flist, key=lambda name: int(os.path.basename(name)[4:-4])) # sort file based on nature number + test_data = [] + for i in flist: + data = loadmat(i)['data'] + # data = data - np.mean(data, 0, keepdims=True) + # data = data - np.mean(data, 1, keepdims=True) + data = data / np.max(np.abs(data[:])) + test_data.append(data) + + data = torch.from_numpy(np.array(test_data)).to(device, torch.float) + out = net(data)['last'] + # calculate the loss + all_out = out.detach().cpu().numpy() + # visualize the result in Matlab + savemat(folder_name + '/rnn_test_{}_{}.mat'.format(args.model_id, fn[-8:]), {'all_out': all_out}) + print('Save output as:', folder_name + '/rnn_test_{}_{}.mat'.format(args.model_id, fn[-8:])) + print('Total run time:', time.time() - start_time) + +if __name__ == '__main__': + main() + diff --git a/eval_real_data.exe b/eval_real_data.exe deleted file mode 100644 index bab17f2..0000000 Binary files a/eval_real_data.exe and /dev/null differ diff --git a/eval_sim.py b/eval_sim.py new file mode 100644 index 0000000..1ebc4f3 --- /dev/null +++ b/eval_sim.py @@ -0,0 +1,139 @@ +import argparse +import os +import time +from scipy.io import loadmat, savemat +import numpy as np +import logging +import datetime +import collections + +import torch +from torch.utils.data import DataLoader + +import network +import loaders +from utils import get_otsu_regions + + +def main(): + start_time = time.time() + # parse the input + parser = argparse.ArgumentParser(description='DeepSIF Model') + parser.add_argument('--workers', default=0, type=int, help='number of data loading workers') + parser.add_argument('--batch_size', default=64, type=int, help='batch size') + parser.add_argument('--device', default='cuda:0', type=str, help='device running the code') + parser.add_argument('--dat', default='SpikeEEGBuildEval', type=str, help='data loader') + parser.add_argument('--test', default='test_sample_source2.mat', type=str, help='test dataset name') + parser.add_argument('--model_id', type=int, default=75, help='model id') + parser.add_argument('--resume', default='', type=str, help='epoch id to resume') + parser.add_argument('--fwd', default='leadfield_75_20k.mat', type=str, help='forward matrix to use') + parser.add_argument('--info', default='', type=str, help='other information regarding this model') + + parser.add_argument('--snr_rsn_ratio', default=0, type=float, help='ratio between real noise and gaussian noise') + parser.add_argument('--lfreq', default=-1, type=int, help='filter EEG data to perform narrow-band analysis') + parser.add_argument('--hfreq', default=-1, type=int, help='filter EEG data to perform narrow-band analysis') + args = parser.parse_args() + + # ======================= PREPARE PARAMETERS ===================================================================================================== + use_cuda = (False) and torch.cuda.is_available() # Only use GPU during training + device = torch.device(args.device if use_cuda else "cpu") + + data_root = 'source/Simulation/' + dis_matrix = loadmat('anatomy/dis_matrix_fs_20k.mat')['raw_dis_matrix'] + + result_root = 'model_result/{}_the_model'.format(args.model_id) + if not os.path.exists(result_root): + print("ERROR: No model {}".format(args.model_id)) + return + fwd = loadmat('anatomy/{}'.format(args.fwd))['fwd'] + + # ================================== LOAD DATA =================================================================================================== + test_data = loaders.__dict__[args.dat](data_root + args.test, fwd=fwd, + args_params={'snr_rsn_ratio': args.snr_rsn_ratio, + 'lfreq': args.lfreq, 'hfreq': args.hfreq}) + test_loader = DataLoader(test_data, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True, shuffle=False) + + # =============================== LOAD MODEL ===================================================================================================== + if args.resume: + fn = fn = os.path.join(result_root, 'epoch_' + args.resume) + else: + fn = os.path.join(result_root, 'model_best.pth.tar') + print("=> Load checkpoint", fn) + if os.path.isfile(fn): + print("=> Found checkpoint '{}'".format(fn)) + checkpoint = torch.load(fn, map_location=torch.device('cpu')) + best_result = checkpoint['best_result'] + net = network.__dict__[checkpoint['arch']](*checkpoint['attribute_list']).to(device) # redefine the weights architecture + net.load_state_dict(checkpoint['state_dict'], strict=False) + print("=> Loaded checkpoint {}, current results: {}".format(fn, best_result)) + + # Define logger + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + handler = logging.FileHandler(result_root + '/outputs_{}.log'.format(checkpoint['arch'])) + handler.setLevel(logging.INFO) + logger.addHandler(handler) + logger.info("=================== Evaluation mode: {} ====================================".format(datetime.datetime.now())) + logger.info("Testing data is {}".format(args.test)) + # Save every parameters in args + for v in args.__dict__: + if v not in ['workers', 'train', 'test']: + logger.info('{} is {}'.format(v, args.__dict__[v])) + else: + print("ERROR: no checkpoint found") + return + + print('Number of parameters:', net.count_parameters()) + print('Prepare time:', time.time() - start_time) + + # =============================== EVALUATION ===================================================================================================== + net.eval() + + eval_dict = collections.defaultdict(list) + eval_dict['all_out'] = [] # DeepSIF output + eval_dict['all_nmm'] = [] # Ground truth source activity + eval_dict['all_regions'] = [] # DeepSIF identified source regions + eval_dict['all_loss'] = 0 # MSE Loss + criterion = torch.nn.MSELoss(reduction='sum') + + with torch.no_grad(): + + for batch_idx, sample_batch in enumerate(test_loader): + + if batch_idx > 0: + break + + data = sample_batch['data'].to(device, torch.float) + nmm = sample_batch['nmm'].numpy() + label = sample_batch['label'].numpy() + model_output = net(data) + out = model_output['last'] + # calculate loss function + # nmm_torch = sample_batch['nmm'].to(device, torch.float) + # eval_dict['all_loss'] = eval_dict['all_loss'] + criterion(out, nmm_torch).data.numpy() + # ----- SAVE EVERYTHING TO EXAMINE LATER (not suitable for large test dataset) ------- + # eval_dict['all_out'].append(out.cpu().numpy()) + # eval_dict['all_eeg'].append(data.cpu().numpy()) + + # ----- ONLY SAVE IDENTIFIED REGION -------------------------------------------------- + eval_results = get_otsu_regions(out.cpu().numpy(), label) + # calculate metrics as a sanity check + # eval_results = get_otsu_regions(out.cpu().numpy(), label, args_params = {'dis_matrix': dis_matrix}) + # eval_dict['precision'].extend(eval_results['precision']) + # eval_dict['recall'].extend(eval_results['recall']) + # eval_dict['le'].extend(eval_results['le']) + + eval_dict['all_regions'].extend(eval_results['all_regions']) + eval_dict['all_out'].extend(eval_results['all_out']) + # ------------------------------------------------------------------------------------ + for kk in range(out.size(0)): + eval_dict['all_nmm'].append(nmm[kk, :, label[kk, :, 0]]) # Only save activity in the center region + # lb = label[kk, :, :] # Save activities in all source regions + # eval_dict['all_nmm'].append(nmm[kk, :, lb[np.logical_not(ispadding(lb))]]) + + savemat(fn + '_preds_{}{}.mat'.format(args.test[:-4], args.info), eval_dict) + + +if __name__ == '__main__': + main() + diff --git a/eval_simulation.exe b/eval_simulation.exe deleted file mode 100644 index 43dad7e..0000000 Binary files a/eval_simulation.exe and /dev/null differ diff --git a/forward/README.md b/forward/README.md new file mode 100644 index 0000000..0963754 --- /dev/null +++ b/forward/README.md @@ -0,0 +1,15 @@ +## DeepSIF: Train Data Generation + +### The Virtual Brain Simulation +```bash +python generate_tvb_data.py --a_start 0 --a_end 10 +``` +The simulation for each region can also run in parallel. (Require multiprocessing installed.) + +### Process Raw TVB Data Prepare Training/Testing Dataset +Run in Matlab +```matlab +process_raw_nmm +generate_sythetic_source +``` +The output of ```generate_sythetic_source``` can be used as input for ```loaders.SpikeEEGBuild``` or ```loaders.SpikeEEGBuildEval``` diff --git a/forward/generate_sythetic_source.m b/forward/generate_sythetic_source.m new file mode 100644 index 0000000..b65cb1a --- /dev/null +++ b/forward/generate_sythetic_source.m @@ -0,0 +1,281 @@ +clear +train = 0; +n_sources = 2; +load('../anatomy/fs_cortex_20k_inflated.mat') +load('../anatomy/fs_cortex_20k.mat') +load('../anatomy/fs_cortex_20k_region_mapping.mat'); +% when load mat in python, python cannot read nan properly, so use a magic number to represent nan when saving +NAN_NUMBER = 15213; +MAX_SIZE = 70; +if train + nper = 100; % Number of nmm spike samples + n_data = 40; + n_iter = 48; % The number of variations in each source center + ds_type = 'train'; +else + nper = 10; + n_data = 1; + n_iter = 3; + ds_type = 'test'; +end +%% ======================================================================== +%=============== Generate Source Patch ==================================== +%% ======== Region Growing Get Candidate Source Regions =================== +selected_region_all = cell(994, 1); +for i=1:994 + % get source direction + selected_region_all{i} = []; + region_id =i; + all_nb = cell(1,4); + all_nb{1} = find_nb_rg(nbs, region_id, region_id); % first layer regions + all_nb{2} = find_nb_rg(nbs, all_nb{1}, [region_id, all_nb{1}]); % second layer regions + v0 = get_direction(centre(region_id, :),centre(all_nb{1}, :)); % direction between the center region and first layer neighbors + angs = zeros(size(v0,1),1); + for k=1:size(v0,1) + CosTheta = max(min(dot(v0(1,:),v0(k,:)),1),-1); + angs(k) = real(acosd(CosTheta)); + end + [~,ind] = sort(angs); + ind = ind(1:ceil(length(angs)/2)); % directions to grow the region + % second layer neighbours + for iter = 1:5 + all_rg = cell(1,4); + for k=1:length(ind) + ii = ind(k); + all_rg(1:2) = all_nb(1:2); + v = get_direction(centre(region_id, :),centre(all_rg{1}(ii), :)); + all_rg{2} = all_rg{2}(get_region_with_dir(v, centre(region_id, :), centre(all_rg{2}, :),1,0)); + [add_rg, rm_rg] = smooth_region(nbs, [region_id,all_nb{1},all_rg{2}]); + final_r = setdiff([region_id, all_nb{1}, all_rg{2} add_rg],rm_rg, 'stable')-1; + selected_region_all{i} = [selected_region_all{i};final_r NAN_NUMBER*ones(1,MAX_SIZE-length(final_r))]; + end + end + % third layer neighbours + for iter = 1:5 + all_rg = cell(1,4); + for k=1:length(ind) + ii = ind(k); + all_rg(1:2) = all_nb(1:2); + v = get_direction(centre(region_id, :),centre(all_rg{1}(ii), :)); + all_rg{2} = all_rg{2}(get_region_with_dir(v, centre(region_id, :), centre(all_rg{2}, :),1,0.1)); + all_rg{3} = find_nb_rg(nbs, all_rg{2}, [region_id, all_rg{1}, all_rg{2}]); + all_rg{3} = all_rg{3}(get_region_with_dir(v, centre(region_id, :), centre(all_rg{3}, :),1,-0.15)); + [add_rg, rm_rg] = smooth_region(nbs, [region_id,all_nb{1},all_rg{2},all_rg{3}]); + final_r = setdiff([region_id, all_nb{1}, all_rg{2} all_rg{3} add_rg],rm_rg, 'stable')-1; + selected_region_all{i} = [selected_region_all{i};final_r NAN_NUMBER*ones(1,MAX_SIZE-length(final_r))]; + end + end +% % fourth neighbours +% for iter = 1:5 +% all_rg = cell(1,4); +% for k=1:length(ind) +% ii = ind(k); +% all_rg(1:2) = all_nb(1:2); +% v = get_direction(centre(region_id, :),centre(all_rg{1}(ii), :)); +% all_rg{2} = all_rg{2}(get_region_with_dir(v, centre(region_id, :), centre(all_rg{2}, :),1,0.2)); +% all_rg{3} = find_nb_rg(nbs, all_rg{2}, [region_id, all_rg{1}, all_rg{2}]); +% all_rg{3} = all_rg{3}(get_region_with_dir(v, centre(region_id, :), centre(all_rg{3}, :),1,-0.1)); +% all_rg{4} = find_nb_rg(nbs, all_rg{3}, [region_id, all_rg{1}, all_rg{2}, all_rg{3}]); +% all_rg{4} = all_rg{4}(get_region_with_dir(v, centre(region_id, :), centre(all_rg{4}, :),1,-0.35)); +% [add_rg, rm_rg] = smooth_region(nbs, [region_id,all_nb{1},all_nb{2},all_rg{3},all_rg{4}]); +% final_r = setdiff([region_id, all_nb{1}, all_nb{2}, all_rg{3},all_rg{4}, add_rg],rm_rg, 'stable')-1; +% if length(final_r) < 71 +% selected_region_all{i} = [selected_region_all{i};final_r NAN_NUMBER*ones(1,MAX_SIZE-length(final_r))]; +% end +% end +% end +end +%% ======== Get Region Center for Each Sample ============================= +selected_region = NAN_NUMBER*ones(994*n_iter, n_sources, MAX_SIZE); +n_iter_list = nan(n_iter*(n_sources-1), 994); +for i = 1:n_iter + for k=1:(n_sources-1) + n_iter_list(i+(k-1)*n_iter,:) = randperm(994); + end +end +n_iter_list(n_iter+1,:) = 1:994; +%% ======== Build Source Patch ============================================ +for kk = 1:n_iter + for ii = 1:994 + idx = 994*(kk-1) + ii; + tr = selected_region_all{ii}; + if kk <= size(tr, 1) && train + selected_region(idx,1,:) = tr(kk,:); + else + selected_region(idx,1,:) = tr(randi([1,size(tr,1)],1,1),:); + end + for k=2:n_sources + tr = selected_region_all{n_iter_list(kk+n_iter*(k-2),ii)}; + selected_region(idx,k,:) = tr(randi([1,size(tr,1)],1,1),:); + end + end +end +selected_region_raw = selected_region; +selected_region = reshape(permute(selected_region_raw, [3,2,1]), MAX_SIZE*n_sources, 994, n_iter); +selected_region = permute(selected_region,[1,3,2]); +selected_region = reshape(repmat(selected_region, 4, 1, 1), MAX_SIZE, n_sources, []); % 4 SNR levels +selected_region = permute(selected_region,[3,2,1]); +%% SAVE +dataset_name = 'source1'; +save([ds_type '_sample_' dataset_name '.mat'], 'selected_region') +%% ======================================================================== +%=============== Generate Other Parameters================================= +%% NMM Signal Waveform +random_samples = randi([0,nper-1],994*n_iter*4,n_sources); % the waveform index for each source +nmm_idx = (selected_region(:,:,1)+1)*nper + random_samples + 1; +save([ds_type '_sample_' dataset_name '.mat'],'nmm_idx', 'random_samples', '-append') +%% SNR +current_snr = reshape(repmat(5:5:20,n_iter*994,1)',[],1); +save([ds_type '_sample_' dataset_name '.mat'],'current_snr', '-append') +%% Scaling Factor +load('../anatomy/leadfield_75_20k.mat'); +gt = load([ds_type '_sample_' dataset_name '.mat']); +scale_ratio = []; +n_source = size(gt.selected_region, 2); +parfor i=1:size(gt.selected_region, 1) + for k=1:n_source + a = gt.selected_region(i,k,:); + a = a(:); + a(a>1000) = []; + if train + scale_ratio(i,k,:) = find_alpha(a+1, random_samples(i, k), fwd, 10:2:20); + + else + scale_ratio(i,k,:) = find_alpha(a+1, random_samples(i, k), fwd, [10,15]); + end + end +end +save([ds_type '_sample_' dataset_name '.mat'], 'scale_ratio', '-append') +%% Change Source Magnitude +clear mag_change +point_05 = [40, 60]; % 45,35 % Magnitude falls to half of the centre region +point_05 = randi(point_05); +sigma = 0.8493*point_05; +mag_change = []; +parfor i=1:size(gt.selected_region,1) + for k=1:n_sources + rg = gt.selected_region(i,k,:); + rg(rg>1000) = []; + dis2centre = all_dis(rg(1)+1,rg+1); + mag_change(i,k,:) = [exp(-dis2centre.^2/(2*sigma^2)) NAN_NUMBER*ones(1,size(gt.selected_region,3)-length(rg))]; + end +end +save([ds_type '_sample_' dataset_name '.mat'], 'mag_change', '-append') +%% +function alpha = find_alpha(region_id, nmm_idx, fwd, target_SNR) +% Re-scaling NMM channels in source channels +% +% INPUTS: +% - region_id : source regions, start at 1 +% - nmm_idx : load nmm data +% - fwd : leadfield matrix, num_electrode * num_regions +% - target_SNR : set snr between signal and the background activity. +% OUTPUTS: +% - alpha : the scaling factor for one patch source + +load(['../source/nmm_clip/a' int2str(region_id(1)-1) '/nmm_' int2str(nmm_idx) '.mat']) +spike_shape = data(:,region_id(1))/max(data(:,region_id(1))); +[~, peak_time] = max(spike_shape); +data(:, region_id) = repmat(spike_shape,1,length(region_id)); +[Ps, Pn, ~] = calcualate_SNR(data, fwd, region_id, max(peak_time-50,0):max(peak_time+50,500)); +alpha = sqrt(10.^(target_SNR./10).*Pn./Ps); +end + + +function [Ps, Pn, cur_snr] = calcualate_SNR(nmm, fwd, region_id, spike_ind) +% Caculate SNR at sensor space. +% +% INPUTS: +% - nmm : NMM data with single activation, time * channel +% - fwd : leadfield matrix, num_electrode * num_regions +% - region_id : source regions, start at 1, region_id(1) is the center +% - spike_ind : index to calculate the spike snr +% OUTPUTS: +% - Ps : signal power +% - Pn : noise power +% - cur_snr : current SNR in dB + + sig_eeg = (fwd(:, region_id)*nmm(:, region_id)')'; % time * channel + sig_eeg_rm = sig_eeg - mean(sig_eeg, 1); + dd = 1:size(nmm,2); + dd(region_id) = []; + noise_eeg = (fwd(:,dd)*nmm(:,dd)')'; + noise_eeg_rm = noise_eeg - mean(noise_eeg, 1); + + Ps = norm(sig_eeg_rm(spike_ind,:),'fro')^2/length(spike_ind); + Pn = norm(noise_eeg_rm(spike_ind,:),'fro')^2/length(spike_ind); + cur_snr = 10*log10(Ps/Pn); +end + + +function v = get_direction(a,b) +% Calculate direction between two point +% INPUTS: +% - a,b : points in 3D; size 1*3 +% OUTPUTS: +% - v : direction between two points; size 1*3 +v = b-a; +v = v./mynorm(v,2); +end + + +function rg = find_nb_rg(nbs, centre_rg, prev_layers) +% Find the neighbouring regions of the centre region +% INPUTS: +% - nbs : neighbour regions for each cortical region; 1*994 +% - centre_rg : centre regions +% - prev_layers : regions in inner layers +% OUTPUTS: +% - rg : neighbouring regions +rg = unique(cell2mat(nbs(centre_rg))); +rg(ismember(rg, prev_layers)) = []; +end + + +function [selected_rg] = get_region_with_dir(v, region_centre, nb_points, ratio, bias) +% Select region given the region growing direction +% INPUTS: +% - v : region growing direction +% - region_centre : centre region in 3D +% - nb_points : neighbour region in 3D +% - ratio, bias : adjust the probability of selecting neighbour +% regions (numbers decided by trial and error) +% OUTPUTS: +% - selected_rg : selected neighbouring regions + + v2 = get_direction(region_centre, nb_points); % direction between center region and neighbour regions + dir_range = abs(v2*v'); % dot product between region growing direction and all neighbouring directions + dir_range = ratio*((dir_range-min(dir_range))/(max(dir_range) - min(dir_range))) + bias; % the probability of selecting neighbour regions +% dir_range = 0.5 % Equal probability for all directions + selected_rg = rand(length(dir_range),1) < dir_range; +end + + + +function [add_rg, rm_rg] = smooth_region(nbs, current_regions) +% Clean up the current selected regions; since we randomly select the +% neighbouring regions, there could be "holes" in the source patch. We add +% the regions where all its neighbours are in the current source patch; and +% remove the regions where no neighbours is in current source patch; +% INPUTS: +% - nbs : neighbour regions for each cortical region; 1*994 +% - current_regions : selected regions +% OUTPUTS: +% - selected_rg : selected neighbouring regions + add_rg = []; + rm_rg = []; + all_final_nb = find_nb_rg(nbs, current_regions, []); + all_final_nb = setdiff(all_final_nb, current_regions); + for i=1:length(all_final_nb) + current_rg = all_final_nb(i); + if length(intersect(current_regions, nbs{current_rg})) > length(nbs{current_rg})-2 + add_rg = [add_rg current_rg]; + end + end + for i=1:length(current_regions) + current_rg = current_regions(i); + if length(intersect([current_regions,add_rg], nbs{current_rg})) == 1 + rm_rg = [rm_rg current_rg]; + end + end +end \ No newline at end of file diff --git a/forward/generate_tvb_data.py b/forward/generate_tvb_data.py new file mode 100644 index 0000000..4f4aec0 --- /dev/null +++ b/forward/generate_tvb_data.py @@ -0,0 +1,90 @@ +from scipy.io import savemat +from tvb.simulator.lab import * +import time +import numpy as np +import multiprocessing as mp +import os +import argparse + + +def main(region_id): + """ TVB Simulation to generate raw source space dynamics, unit in mV, and ms + :param region_id: int; source region id, with parameters generating interictal spike activity + """ + if not os.path.isdir('../source/raw_nmm/a{}/'.format(region_id)): + os.mkdir('../source/raw_nmm/a{}/'.format(region_id)) + start_time = time.time() + print('------ Generate data of region_id {} ----------'.format(region_id)) + conn = connectivity.Connectivity.from_file(source_file=os.getcwd()+'/../anatomy/connectivity_76.zip') # connectivity provided by TVB + conn.configure() + + # define A value + num_region = conn.number_of_regions + a_range = [3.5] + A = np.ones((num_region, len(a_range))) * 3.25 # the normal A value is 3.25 + A[region_id, :] = a_range + + # define mean and std + mean_and_std = np.array([[0.087, 0.08, 0.083], [1, 1.7, 1.5]]) + for iter_a in range(A.shape[1]): + use_A = A[:, iter_a] + for iter_m in range(mean_and_std.shape[1]): + + jrm = models.JansenRit(A=use_A, mu=np.array(mean_and_std[0][iter_m]), + v0=np.array([6.]), p_max=np.array([0.15]), p_min=np.array([0.03])) + phi_n_scaling = (jrm.a * 3.25 * (jrm.p_max - jrm.p_min) * 0.5 * mean_and_std[1][iter_m]) ** 2 / 2. + sigma = np.zeros(6) + sigma[4] = phi_n_scaling + + # set the random seed for the random intergrator + # randomStream = np.random.mtrand.RandomState(0) + # noise_class = noise.Additive(random_stream=randomStream, nsig=sigma) + # integ = integrators.HeunStochastic(dt=2 ** -1, noise=noise_class) + + sim = simulator.Simulator( + model=jrm, + connectivity=conn, + coupling=coupling.SigmoidalJansenRit(a=np.array([1.0])), + integrator=integrators.HeunStochastic(dt=2 ** -1, noise=noise.Additive(nsig=sigma)), + monitors=(monitors.Raw(),) + ).configure() + + # run 200s of simulation, cut it into 20 pieces, 10s each. (Avoid saving large files) + for iii in range(20): + siml = 1e4 + out = sim.run(simulation_length=siml) + (t, data), = out + data = (data[:, 1, :, :] - data[:, 2, :, :]).squeeze().astype(np.float32) + + # # in the fsaverage5 mapping, there is no vertices corresponding to region 7,325,921, 949, so change label 994-998 to those id + # data[:, 7] = data[:, 994] + # data[:, 325] = data[:, 997] + # data[:, 921] = data[:, 996] + # data[:, 949] = data[:, 995] + # data = data[:, :994] + + savemat('../source/raw_nmm/a{}/mean_iter_{}_a_iter_{}_{}.mat'.format(region_id, iter_m, region_id, iii), + {'time': t, 'data': data, 'A': use_A}) + print('Time for', region_id, time.time() - start_time) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='TVB Simulation') + parser.add_argument('--a_start', type=int, default=0, metavar='t/f', help='start region id') + parser.add_argument('--a_end', type=int, default=1, metavar='t/f', help='end region id') + args = parser.parse_args() + os.environ["MKL_NUM_THREADS"] = "1" + start_time = time.time() + # RUN THE CODE IN PARALLEL + # processes = [mp.Process(target=main, args=(x,)) for x in range(args.a_start, args.a_end)] + # for p in processes: + # p.start() + # # Exit the completed processes + # for p in processes: + # p.join() + # NO PARALLEL + for x in range(args.a_start, args.a_end): + main(x) + print('Total_time', time.time() - start_time) + diff --git a/forward/process_raw_nmm.m b/forward/process_raw_nmm.m new file mode 100644 index 0000000..91c103f --- /dev/null +++ b/forward/process_raw_nmm.m @@ -0,0 +1,222 @@ +function process_raw_nmm(filename, varargin) +% Scan through the raw NMM data to find the spike data + +% %%%%%%%%%%%%%% SETUP PARAMETERS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +p = inputParser; +addRequired(p,'filename',@ischar); +addParameter(p,'leadfield_name','leadfield_75_20k.mat', @ischar); +parse(p, filename, varargin{:}) +filename = p.Results.filename; +headmodel = load(['../anatomy/' p.Results.leadfield_name]); +fwd = headmodel.fwd; +savefile_path = '../source/'; + +% ------------------------------------------------------------------------- +iter_list = 1:5; % the iter during NMM generation. +previous_iter_spike_num = zeros(1, 994); +for i_iter = 1:length(iter_list) + iter = iter_list(i_iter); + if isempty(dir([savefile_path 'nmm_' filename '/clip_info/iter' int2str(iter)])) + mkdir([savefile_path 'nmm_' filename '/clip_info/iter' int2str(iter)]) + end + + % ------- Resume running if the process was interupted ---------------- + done = dir([savefile_path 'nmm_' filename '/clip_info/iter' int2str(iter) '/iter_' int2str(iter) '_i_*']); + finished_regions = zeros(1, length(done)); + for i = 1:length(done) + finished_regions(i) = str2num(done(i).name(10:end-3)); + end + remaining_regions = setdiff(1:994, finished_regions+1); + if isempty(remaining_regions) + continue; + end + + % -------- start the main progress -----------------------------------% + for ii = 183:183%length(remaining_regions) + + i = remaining_regions(ii); + % creat folders to save nmm files + if isempty(dir([savefile_path 'nmm_' filename '/a' int2str(i-1)])) + mkdir([savefile_path 'nmm_' filename '/a' int2str(i-1)]) + end + + fn = [savefile_path 'raw_nmm/a' int2str(i-1) '/mean_iter_' int2str(iter) '_a_iter_' int2str(i-1)]; + raw_data = load([fn '_ds.mat']); + nmm = raw_data.all_data; + % nmm = downsample(nmm, 4); + [spike_time, spike_chan] = find_spike_time(nmm); % Process raw tvb output to find the spike peak time + + % ----------- select the spikes we want to extract ---------------% + rule1 = (spike_chan == i); % there is spike in the source region + start_time = floor(spike_time(rule1)/500) * 500 + 1; % there is no source in other region in the clip + clear_ind = repmat(start_time, [900, 1]) + (-200:699)'; % 900 * num_spike + rule2 = (sum(ismember(clear_ind, spike_time(~rule1)), 1) == 0); % there are no other spikes in the clip + spike_time = spike_time(rule1); + spike_time = spike_time(rule2); + + % ----------- Optional : Scale the NMM here----------------------% + alpha_value = find_alpha(nmm, fwd, i, spike_time, 15); + nmm = rescale_nmm_channel(nmm, i, spike_time, alpha_value); + % ------------Save Spike NMM Data --------------------------------% + start_time = floor(spike_time/500) * 500 + 1; + spike_ind = repmat(start_time, [500, 1]) + (0:499)'; +% start_time = floor((spike_time+200)/500) * 500 + 1 - 200; % start time can be changed +% start_time = max(start_time, 101); +% spike_ind = repmat(start_time, [500, 1]) + (0:499)'; + + nmm_data = reshape(nmm(spike_ind,:), 500, [], size(nmm,2)); % size: time * num_spike * channel + save_spikes_(nmm_data, [savefile_path 'nmm_' filename '/a' int2str(i-1) '/nmm_'], previous_iter_spike_num(i)); + previous_iter_spike_num(i) = previous_iter_spike_num(i) + length(spike_time); + % Save something in clip info, so that we can make sure we finish this process + save_struct = struct(); + save_struct.num_spike = previous_iter_spike_num(i); + save_struct.spike_time = spike_time; + parsave([savefile_path 'nmm_' filename '/clip_info/iter' int2str(iter) '/iter_' int2str(iter) '_i_' int2str(i-1) '.mat'], save_struct) + sprintf(['iter_' int2str(iter) '_i_%d is done\n'], i-1) + end % END REGION +end % END ITER +end % ENG FUNCTION + + + +%% --------------------- Helper functions ------------------------------ %% +function parsave(fname, mapObj) + save(fname,'-struct', 'mapObj'); +end + + +function save_spikes_(spike_data, savefile_path, previous_iter_spike_num) +% Save the spike data into seperate files +% INPUTS: spike_data: time * num_spikes * channel; extracted spike data +% savefile_path: string + for iii = 1:size(spike_data,2) + % The raw data + data = squeeze(spike_data(:,iii,:)); + save([savefile_path int2str(iii+previous_iter_spike_num) '.mat'], 'data', '-v7') + end +end + + +function [spike_time, spike_chan] = find_spike_time(nmm) +% Process raw tvb output to find the spike peak time. +% +% INPUTS: +% - nmm : (Downsampled) raw tvb output, time * channel +% OUTPUTS: +% - spike_time : the spike peak time in the (downsampled) NMM data +% - spike_chan : the spike channel for each spike + + spikes_nmm = nmm; + spikes_nmm(nmm < 8) = 0; % find the spiking activity stronger than the background + local_max = islocalmax(spikes_nmm); % find the peak + [spike_time, spike_chan] = find(local_max); + [spike_time, sort_ind] = sort(spike_time); + spike_chan = spike_chan(sort_ind); % sort the activity based on time + use_ind = (spike_time-249 > 0) & ... % ignore the spikes at the beginning or end of the signal + (spike_time+250 < size(nmm, 1) & ... + [1 diff(spike_time)'>100]'); % ignore peaks close together for now (will have signals with close peaks in multi-source condition) + spike_time = spike_time(use_ind)'; + spike_chan = spike_chan(use_ind)'; + +end + + + +function [alpha] = find_alpha(nmm, fwd, region_id, time_spike, target_SNR) +% Find the scaling factor for the NMM channels. +% +% INPUTS: +% - nmm : NMM data with single activation, time * channel +% - fwd : leadfield matrix, num_electrode * num_regions +% - region_id : source regions, start at 1, region_id(1) is the center +% - time_spike : spike peak time +% - target_SNR : set snr between signal and the background activity. +% OUTPUTS: +% - spike_time : the spike peak time in the (downsampled) NMM data +% - spike_chan : the spike channel for each spike +% - alpha : the scaling factor for one patch source + + spike_ind = repmat(time_spike, [200, 1]) + (-99:100)'; + spike_ind = min(max(spike_ind(:),0), size(nmm,1)); % make sure the index is not out of range +% spike_ind = max(0, time_spike-100): max(time_spike+100,size(nmm,1)); % make sure the index is not out of range + spike_shape = nmm(:,region_id(1)); %/max(nmm(:,region_id(1))); + nmm(:, region_id) = repmat(spike_shape,1,length(region_id)); + % calculate the scaling factor + [Ps, Pn, ~] = calcualate_SNR(nmm, fwd, region_id, spike_ind); + alpha = sqrt(10^(target_SNR/10)*Pn/Ps); +end + + +function scaled_nmm = rescale_nmm_channel(nmm, region_id, spike_time, alpha_value) +% Re-scaling NMM channels in source channels +% +% INPUTS: +% - nmm : NMM data with single activation, time * channel +% - spike_time : spike peak time +% - region_id : source regions, start at 1 +% - alpha_value: scaling factor +% OUTPUTS: +% - scaled_nmm : scaled NMM in the source region; time * channel + + nmm_rm = nmm - mean(nmm, 1); + for i=1:length(spike_time) + sig = nmm_rm(spike_time(i)-249:spike_time(i)+250, region_id); % one second data around the peak + + thre = 0.1; + small_ind = find(abs(sig)450) | (small_ind < 50)) = []; + start_ind = find((small_ind-250)<0); % spike start time + % test 1 + while isempty(start_ind) + thre = thre+0.05; + small_ind = find(abs(sig)450) | (small_ind < 50)) = []; + start_ind = find((small_ind-250)<0); + end + start_sig = small_ind(start_ind(end)); + + % test 2 + [~, min_ind] = min(sig(301:400)); + min_ind = min_ind + 301; + end_ind = find((small_ind-min_ind)>0); + while isempty(end_ind) + thre = thre+0.05; + small_ind = find(abs(sig)450) | (small_ind < 50)) = []; + end_ind = find((small_ind-min_ind)>0); + end + end_sig = small_ind(end_ind(1)); % spike end time + + sig(start_sig:end_sig) = sig(start_sig:end_sig) * alpha_value; % scale the signal + nmm_rm(spike_time(i)-249:spike_time(i)+250, region_id) = sig; + end + scaled_nmm = nmm_rm + mean(nmm, 1); + + +end + + +function [Ps, Pn, cur_snr] = calcualate_SNR(nmm, fwd, region_id, spike_ind) +% Caculate SNR at sensor space. +% +% INPUTS: +% - nmm : NMM data with single activation, time * channel +% - fwd : leadfield matrix, num_electrode * num_regions +% - region_id : source regions, start at 1, region_id(1) is the center +% - spike_ind : index to calculate the spike snr +% OUTPUTS: +% - Ps : signal power +% - Pn : noise power +% - cur_snr : current SNR in dB + + sig_eeg = (fwd(:, region_id)*nmm(:, region_id)')'; % time * channel + sig_eeg_rm = sig_eeg - mean(sig_eeg, 1); + dd = 1:size(nmm,2); + dd(region_id) = []; + noise_eeg = (fwd(:,dd)*nmm(:,dd)')'; + noise_eeg_rm = noise_eeg - mean(noise_eeg, 1); + + Ps = norm(sig_eeg_rm(spike_ind,:),'fro')^2/length(spike_ind); + Pn = norm(noise_eeg_rm(spike_ind,:),'fro')^2/length(spike_ind); + cur_snr = 10*log10(Ps/Pn); +end \ No newline at end of file diff --git a/loaders.py b/loaders.py new file mode 100644 index 0000000..c9f4818 --- /dev/null +++ b/loaders.py @@ -0,0 +1,291 @@ +from torch.utils.data import Dataset +import numpy as np +from scipy.io import loadmat, savemat +import h5py +from utils import add_white_noise, ispadding +import random +import mne + + +class SpikeEEGBuild(Dataset): + + """Dataset, generate input/output on the run + + Attributes + ---------- + data_root : str + Dataset file location + fwd : np.array + Size is num_electrode * num_region + data : np.array + TVB output data + dataset_meta : dict + Information needed to generate data + selected_region: spatial model for the sources; num_examples * num_sources * max_size + num_examples: num_examples in this dataset + num_sources: num_sources in one example + max_size: cortical regions in one source patch; first value is the center region id; variable length, padded to max_size + (set to 70, an arbitrary number) + nmm_idx: num_examples * num_sources: index of the TVB data to use as the source + scale_ratio: scale the waveform maginitude in source region; num_examples * num_sources * num_scale_ratio (num_snr_level) + mag_change: magnitude changes inside a source patch; num_examples * num_sources * max_size + weight decay inside a patch; equals to 1 in the center region; variable length; padded to max_size + sensor_snr: the Gaussian noise added to the sensor space; num_examples * 1; + + dataset_len : int + size of the dataset, can be set as a small value during debugging + """ + + def __init__(self, data_root, fwd, transform=None, args_params=None): + + # args_params: optional parameters; can be dataset_len + + self.file_path = data_root + self.fwd = fwd + self.transform = transform + + self.data = [] + self.dataset_meta = loadmat(self.file_path) + if 'dataset_len' in args_params: + self.dataset_len = args_params['dataset_len'] + else: # use the whole dataset + self.dataset_len = self.dataset_meta['selected_region'].shape[0] + if 'num_scale_ratio' in args_params: + self.num_scale_ratio = args_params['num_scale_ratio'] + else: + self.num_scale_ratio = self.dataset_meta['scale_ratio'].shape[2] + + def __getitem__(self, index): + + if not self.data: + self.data = h5py.File('{}_nmm.h5'.format(self.file_path[:-12]), 'r')['data'] + + raw_lb = self.dataset_meta['selected_region'][index].astype(np.int) # labels with padding + lb = raw_lb[np.logical_not(ispadding(raw_lb))] # labels without padding + raw_nmm = np.zeros((500, self.fwd.shape[1])) + + for kk in range(raw_lb.shape[0]): # iterate through number of sources + curr_lb = raw_lb[kk, np.logical_not(ispadding(raw_lb[kk]))] + current_nmm = self.data[self.dataset_meta['nmm_idx'][index][kk]] + + ssig = current_nmm[:, [curr_lb[0]]] # waveform in the center region + # set source space SNR + ssig = ssig / np.max(ssig) * self.dataset_meta['scale_ratio'][index][kk][random.randint(0, self.num_scale_ratio - 1)] + current_nmm[:, curr_lb] = ssig.reshape(-1, 1) + # set weight decay inside one source patch + weight_decay = self.dataset_meta['mag_change'][index][kk] + weight_decay = weight_decay[np.logical_not(ispadding(weight_decay))] + current_nmm[:, curr_lb] = ssig.reshape(-1, 1) * weight_decay + + raw_nmm = raw_nmm + current_nmm + + eeg = np.matmul(self.fwd, raw_nmm.transpose()) # project data to sensor space; num_electrode * num_time + csnr = self.dataset_meta['sensor_snr'][index] + noisy_eeg = add_white_noise(eeg, csnr).transpose() + + noisy_eeg = noisy_eeg - np.mean(noisy_eeg, axis=0, keepdims=True) # time + noisy_eeg = noisy_eeg - np.mean(noisy_eeg, axis=1, keepdims=True) # channel + noisy_eeg = noisy_eeg / np.max(np.abs(noisy_eeg)) + + # get the training output + empty_nmm = np.zeros_like(raw_nmm) + empty_nmm[:, lb] = raw_nmm[:, lb] + empty_nmm = empty_nmm / np.max(empty_nmm) + # Each data sample + sample = {'data': noisy_eeg.astype('float32'), + 'nmm': empty_nmm.astype('float32'), + 'label': raw_lb, + 'snr': csnr} + if self.transform: + sample = self.transform(sample) + + # savemat('{}/data{}.mat'.format(self.file_path[0][:-4],index),{'data':noisy_eeg,'label':raw_lb,'nmm':empty_nmm[:,lb]}) + return sample + + def __len__(self): + return self.dataset_len + + +class SpikeEEGLoad(Dataset): + + """Dataset, load pregenerated input/output pair + + Attributes + ---------- + data_root : str + Dataset file location + fwd : np.array + Size is num_electrode * num_region + dataset_len : int + size of the dataset, can be set as a small value during debugging + """ + + def __init__(self, data_root, fwd, transform=None, args_params=None): + + # args_params: optional parameters; can be dataset_len + + self.file_path = data_root + self.fwd = fwd + self.transform = transform + if 'dataset_len' in args_params: + self.dataset_len = args_params['dataset_len'] + else: # use the whole dataset + self.dataset_len = len(dir('{}/data*.mat')) + + def __getitem__(self, index): + + # load data saved as separate files using loadmat + raw_data = loadmat('{}/data{}'.format(self.file_path, index)) + sample = {'data': raw_data['data'].astype('float32'), + 'nmm': raw_data['nmm'].astype('float32'), + 'label': raw_data['label'], + 'snr': raw_data['csnr']} + + if self.transform: + sample = self.transform(sample) + + return sample + + def __len__(self): + return self.dataset_len + + +class SpikeEEGBuildEval(Dataset): + + """Dataset, generate test data under different conditions to evaluate the model under different conditions + + Attributes + ---------- + data_root : str + Dataset file location + fwd : np.array + Size is num_electrode * num_region + data : np.array + TVB output data + dataset_meta : dict + Information needed to generate data + selected_region: spatial model for the sources; num_examples * num_sources * max_size + num_examples: num_examples in this dataset + num_sources: num_sources in one example + max_size: cortical regions in one source patch; first value is the center region id; variable length, padded to max_size + (set to 70, an arbitrary number) + nmm_idx: num_examples * num_sources: index of the TVB data to use as the source + scale_ratio: scale the waveform maginitude in source region; num_examples * num_sources * num_scale_ratio (num_snr_level) + mag_change: magnitude changes inside a source patch; num_examples * num_sources * max_size + weight decay inside a patch; equals to 1 in the center region; variable length; padded to max_size + sensor_snr: the Gaussian noise added to the sensor space; num_examples * 1; + + dataset_len : int + size of the dataset, can be set as a small value during debugging + + eval_params : dict + New attributes compare to SpikeEEGBuild, depending on the test running, keys can be + lfreq : int; high pass cut-off frequency; filter EEG data to perform narrow-band analysis + hfreq : int; low pass cut-off frequency; filter EEG data to perform narrow-band analysis + snr_rsn_ratio: float; [0, 1]; ratio between real noise and gaussian noise + + + """ + + def __init__(self, data_root, fwd, transform=None, args_params=None): + + # args_params: optional parameters; can be dataset_len, num_scale_ratio + + self.file_path = data_root + self.fwd = fwd + self.transform = transform + + self.data = [] + self.dataset_meta = loadmat(self.file_path) + self.eval_params = dict() + + # check args_params: + if 'dataset_len' in args_params: + self.dataset_len = args_params['dataset_len'] + else: # use the whole dataset + self.dataset_len = self.dataset_meta['selected_region'].shape[0] + if 'num_scale_ratio' in args_params: + self.num_scale_ratio = args_params['num_scale_ratio'] + else: + self.num_scale_ratio = self.dataset_meta['scale_ratio'].shape[2] + + if 'snr_rsn_ratio' in args_params and args_params['snr_rsn_ratio']: # Need to add realistic noise + self.eval_params['rsn'] = loadmat('anatomy/realistic_noise.mat') + self.eval_params['snr_rsn_ratio'] = args_params['snr_rsn_ratio'] + if 'lfreq' in args_params and args_params['lfreq'] > 0: + if 'hfreq' in args_params and args_params['hfreq'] > 0: + self.eval_params['lfreq'] = args_params['lfreq'] + self.eval_params['hfreq'] = args_params['hfreq'] + else: + print('WARNING: NEED TO ASSIGN BOTH LOW-PASS AND HIGH-PASS CUT-OFF FREQ, IGNORE FILTERING') + + def __getitem__(self, index): + + if not self.data: + self.data = h5py.File('{}_nmm.h5'.format(self.file_path[:-12]), 'r')['data'] + + raw_lb = self.dataset_meta['selected_region'][index].astype(np.int) # labels with padding + lb = raw_lb[np.logical_not(ispadding(raw_lb))] # labels without padding + raw_nmm = np.zeros((500, self.fwd.shape[1])) + + for kk in range(raw_lb.shape[0]): # iterate through number of sources + curr_lb = raw_lb[kk, np.logical_not(ispadding(raw_lb[kk]))] + current_nmm = self.data[self.dataset_meta['nmm_idx'][index][kk]] + + ssig = current_nmm[:, [curr_lb[0]]] # waveform in the center region + # set source space SNR + ssig = ssig / np.max(ssig) * self.dataset_meta['scale_ratio'][index][kk][random.randint(0, self.num_scale_ratio - 1)] + current_nmm[:, curr_lb] = ssig.reshape(-1, 1) + # set weight decay inside one source patch + weight_decay = self.dataset_meta['mag_change'][index][kk] + weight_decay = weight_decay[np.logical_not(ispadding(weight_decay))] + current_nmm[:, curr_lb] = ssig.reshape(-1, 1) * weight_decay + + raw_nmm = raw_nmm + current_nmm + + eeg = np.matmul(self.fwd, raw_nmm.transpose()) # project data to sensor space; num_electrode * num_time + csnr = self.dataset_meta['sensor_snr'][index] + + # add noise to sensor space + if 'rsn' in self.eval_params: + noisy_eeg = add_white_noise(eeg, csnr, + {'ratio': self.eval_params['snr_rsn_ratio'], + 'rndata': self.eval_params['rsn']['data'], + 'rnpower': self.eval_params['rsn']['npower']}).transpose() + else: + noisy_eeg = add_white_noise(eeg, csnr).transpose() + + # filter data into narrow band + if 'lfreq' in self.eval_params: + noisy_eeg = mne.filter.filter_data(np.tile(noisy_eeg.transpose(),(1,5)), 500, self.eval_params['lfreq'], self.eval_params['hfreq'], + verbose=False).transpose() + noisy_eeg = noisy_eeg[1000:1500] + + noisy_eeg = noisy_eeg - np.mean(noisy_eeg, axis=0, keepdims=True) # time + noisy_eeg = noisy_eeg - np.mean(noisy_eeg, axis=1, keepdims=True) # channel + noisy_eeg = noisy_eeg / np.max(np.abs(noisy_eeg)) + + # get the training output + empty_nmm = np.zeros_like(raw_nmm) + empty_nmm[:, lb] = raw_nmm[:, lb] + empty_nmm = empty_nmm / np.max(empty_nmm) + # Each data sample + sample = {'data': noisy_eeg.astype('float32'), + 'nmm': empty_nmm.astype('float32'), + 'label': raw_lb, + 'snr': csnr} + if self.transform: + sample = self.transform(sample) + + # savemat('{}/data{}.mat'.format(self.file_path[0][:-4],index),{'data':noisy_eeg,'label':raw_lb,'nmm':empty_nmm[:,lb]}) + return sample + + def __len__(self): + return self.dataset_len + + +# from matplotlib import pyplot as plt +# plt.subplot(1,2,1) +# plt.plot(noisy_eeg) +# plt.subplot(1,2,2) +# plt.plot(empty_nmm[:,lb]) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..92bab23 --- /dev/null +++ b/main.py @@ -0,0 +1,196 @@ +import argparse +import os +import time +from scipy.io import loadmat, savemat +import numpy as np +import logging +import datetime + +import torch +from torch import optim +from torch.utils.data import DataLoader + +import network +import loaders + + +def main(): + start_time = time.time() + # parse the input + parser = argparse.ArgumentParser(description='DeepSIF Model') + parser.add_argument('--save', type=int, default=True, help='save each epoch or not') + parser.add_argument('--workers', default=0, type=int, help='number of data loading workers') + parser.add_argument('--batch_size', default=64, type=int, help='batch size') + parser.add_argument('--device', default='cuda:0', type=str, help='device running the code') + parser.add_argument('--arch', default='TemporalInverseNet', type=str, help='network achitecture class') + parser.add_argument('--dat', default='SpikeEEGBuild', type=str, help='data loader') + parser.add_argument('--train', default='test_sample_source2.mat', type=str, help='train dataset name or directory') + parser.add_argument('--test', default='test_sample_source2.mat', type=str, help='test dataset name or directory') + parser.add_argument('--model_id', default=75, type=int, help='model id') + parser.add_argument('--lr', default=3e-4, type=float, help='learning rate') + parser.add_argument('--resume', default='1', type=str, help='epoch id to resume') + parser.add_argument('--epoch', default=20, type=int, help='total number of epoch') + parser.add_argument('--fwd', default='leadfield_75_20k.mat', type=str, help='forward matrix to use') + parser.add_argument('--rnn_layer', default=3, type=int, help='number of rnn layer') + parser.add_argument('--info', default='', type=str, help='other information regarding this model') + args = parser.parse_args() + + # ======================= PREPARE PARAMETERS ===================================================================================================== + use_cuda = torch.cuda.is_available() + device = torch.device(args.device if use_cuda else "cpu") + + data_root = 'source/Simulation/' + result_root = 'model_result/{}_the_model'.format(args.model_id) + if not os.path.exists(result_root): + os.makedirs(result_root) + fwd = loadmat('anatomy/{}'.format(args.fwd))['fwd'] + + # Define logger + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + handler = logging.FileHandler(result_root + '/outputs_{}.log'.format(args.arch)) + handler.setLevel(logging.INFO) + logger.addHandler(handler) + logger.info("============================= {} ====================================".format(datetime.datetime.now())) + logger.info("Training data is {}, and testing data is {}".format(args.train, args.test)) + # Save every parameters in args + for v in args.__dict__: + if v not in ['workers', 'train', 'test']: + logger.info('{} is {}'.format(v, args.__dict__[v])) + + # ================================== LOAD DATA =================================================================================================== + train_data = loaders.__dict__[args.dat](data_root + args.train, fwd=fwd, + args_params={'dataset_len': 4}) + train_loader = DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True, shuffle=True) + test_data = loaders.__dict__[args.dat](data_root + args.test, fwd=fwd, args_params={'dataset_len': 4}) + test_loader = DataLoader(test_data, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True, shuffle=False) + + # ================================== CREATE MODEL ================================================================================================ + + net = network.__dict__[args.arch](num_sensor=75, num_source=994, rnn_layer=args.rnn_layer, + spatial_model=network.MLPSpatialFilter, + temporal_model=network.TemporalFilter, + spatial_output='value_activation', temporal_output='rnn', spatial_activation='ELU', temporal_activation='ELU', + temporal_input_size=500).to(device) + optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-6) + # lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5, verbose=True, threshold=0.001) + criterion = torch.nn.MSELoss(reduction='sum') + + args.start_epoch = 0 + best_result = np.Inf + train_loss = [] + test_loss = [] + + # =============================== RESUME ========================================================================================================= + if args.resume: + print("=> Load checkpoint", args.resume, "from", result_root) + fn = os.path.join(result_root, 'epoch_' + args.resume) + if os.path.isfile(fn): + print("=> Found checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(fn, map_location=torch.device('cpu')) + args.start_epoch = checkpoint['epoch'] + best_result = checkpoint['best_result'] + # recreate net and optimizer based on the saved model + net = network.__dict__[checkpoint['arch']](*checkpoint['attribute_list']).to(device) # redefine the weights architecture + net.load_state_dict(checkpoint['state_dict'], strict=False) + optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-6) + optimizer.load_state_dict(checkpoint['optimizer']) + # for param_group in optimizer.param_groups: + # param_group['lr'] = param_group['lr'] / checkpoint['lr'] * args.lr + print("=> Loaded checkpoint epoch {}, current results: {}".format(args.resume, best_result)) + tte = loadmat(result_root + '/train_test_error.mat') + train_loss.extend(tte['train_loss'][0][:int(args.resume) + 1]) + test_loss.extend(tte['test_loss'][0][:int(args.resume) + 1]) + else: + print("WARNING: no checkpoint found at '{}', use random weights".format(args.resume)) + + print('Number of parameters:', net.count_parameters()) + print('Prepare time:', time.time() - start_time) + + # =============================== TRAINING ======================================================================================================= + for epoch in range(args.start_epoch + 1, args.epoch): + + # train for one epoch + train_lss_all = train(train_loader, net, criterion, optimizer, {'device': device, 'logger': logger}) + # evaluate on validation set + test_lss_all = validate(test_loader, net, criterion, {'device': device}) + # lr_scheduler.step(test_le) + train_loss.extend([np.sum(np.array(train_lss_all)) / len(train_data)]) + test_loss.extend([np.sum(np.array(test_lss_all))/len(test_data)]) + + print_s = 'Epoch {}: Time:{:6.2f}, '.format(epoch, time.time() - start_time) + \ + 'Train Loss:{:06.5f}'.format(train_loss[-1]) + ', Test Loss:{:06.5f}'.format(test_loss[-1]) + logger.info(print_s) + print(print_s) + is_best = test_loss[-1] < best_result + best_result = min(test_loss[-1], best_result) + if is_best: + torch.save({ + 'epoch': epoch, 'arch': args.arch, 'state_dict': net.state_dict(), 'best_result': best_result, 'lr': args.lr, 'info': args.info, + 'train': args.train, 'test': args.test, 'attribute_list': net.attribute_list, 'optimizer': optimizer.state_dict()}, + result_root + '/model_best.pth.tar') + if args.save: + # save checkpoint + torch.save({ + 'epoch': epoch, 'arch': args.arch, 'state_dict': net.state_dict(), 'best_result': best_result, 'lr': args.lr, 'info': args.info, + 'train': args.train, 'test': args.test, 'attribute_list': net.attribute_list, 'optimizer': optimizer.state_dict()}, + result_root + '/epoch_{}'.format(epoch)) + savemat(result_root + '/train_test_error.mat', {'train_loss': train_loss, 'test_loss': test_loss}) + savemat(result_root + '/train_test_loss_epoch{}.mat'.format(epoch), {'train_loss': train_lss_all, 'test_loss': test_lss_all}) + # END MAIN_TRAIN + + +# START TRAIN FUNC +def train(train_loader, model, criterion, optimizer, args_params): + # args_params: potential parameter inputs, could be "device","logger" + + device = args_params['device'] + logger = args_params['logger'] + # switch to train mode + model.train() + train_loss = [] + start_time = time.time() + for batch_idx, sample_batch in enumerate(train_loader): + # load data + data = sample_batch['data'].to(device) + nmm = sample_batch['nmm'].to(device) + + # training process + optimizer.zero_grad() + model_output = model(data) + out = model_output['last'] + loss = criterion(out, nmm) + loss.backward() + optimizer.step() + + train_loss.append(loss.data.view(1)) + if (batch_idx + 1) % 500 == 0: + print_s = "batch_idx_{}_time_{}_train_loss_{}".format(batch_idx, time.time() - start_time, train_loss[-1]) + logger.info(print_s) + train_loss = torch.cat(train_loss).cpu().numpy() + return train_loss +# END TRAIN + + +# START VALIDATE FUNC +def validate(val_loader, model, criterion, args_params): + # switch to evaluate mode + device = args_params['device'] + model.eval() + val_loss = [] + with torch.no_grad(): + for batch_idx, sample_batch in enumerate(val_loader): + data = sample_batch['data'].to(device) + nmm = sample_batch['nmm'].to(device) + model_output = model(data) + out = model_output['last'] + loss = criterion(out, nmm) + val_loss.append(loss.data.view(1)) + val_loss = torch.cat(val_loss).cpu().numpy() + return val_loss +# END VALIDATE + + +if __name__ == '__main__': + main() + diff --git a/main_train.exe b/main_train.exe deleted file mode 100644 index 1667b40..0000000 Binary files a/main_train.exe and /dev/null differ diff --git a/misc_scripts/README.md b/misc_scripts/README.md new file mode 100644 index 0000000..d31d12b --- /dev/null +++ b/misc_scripts/README.md @@ -0,0 +1,2 @@ +* ```forward_prepare_lfd_matrix```: Calculate the forward matrix based on the cortical segmentation. +* ```eval_calculate_sim_results```: Evaluate estimated sources based on saved resource after performing otsu threshold. diff --git a/misc_scripts/eval_calculate_sim_results.m b/misc_scripts/eval_calculate_sim_results.m new file mode 100644 index 0000000..d6faf3f --- /dev/null +++ b/misc_scripts/eval_calculate_sim_results.m @@ -0,0 +1,81 @@ +% Evaluate estimated sources based on saved resource after performing otsu +% threshold. +clear +dataset_name = '_sample_source2'; +fname = '_sample_source2'; +model_id = '75'; +gt = load(['../source/Simulation/test' dataset_name '.mat']); % ground truth +num_source = size(gt.selected_region, 2); +load(['../model_result/' model_id '_the_model/model_best.pth.tar_preds_test' fname '.mat']); + +load('../anatomy/dis_matrix_fs_20k.mat') +all_dis = raw_dis_matrix; +% Variabled loaded: +% FROM source +% gt.selected_region : array; ground truth source region, start from 0, num_examples * num_source * MAX_SIZE +% (MAX_SIZE=70, the number of cortical regions in each +% example is different, padd with 15213 to size 70 +% FROM model_result +% all_regions : cell; DeepSIF reconstructed source region, start from 0, num_examples * 1 +% all_out : cell; activity in DeepSIF reconstructed source region; num_examples * 1 +% all_num : array: ground truth activity; num_examples * num_source * num_time +% FROM anatomy +% all_dis : distance between cortical source regions + +%% +precision = nan(length(all_out), num_source); +recall = nan(length(all_out), num_source); +le = nan(length(all_out), num_source); +all_corr = nan(length(all_out), num_source); + +recon_regions = cell(length(all_out), num_source); +recon_activity = cell(length(all_out), num_source); + +for i= 1:length(all_out) + + % gather all source regions, remove padded variable. + all_label = reshape(squeeze(gt.selected_region(i,:,:))',[],num_source); + num_region_per_source = sum(~myisnan(all_label),1); + all_label(myisnan(all_label)) = []; + % recon regions + current_regions = all_regions{i}; + if isempty(current_regions) + continue + end + + % assign the each recon source region to its closest source patch + [~, min_ind] = min(all_dis(all_label+1,current_regions+1),[],1); + mapping = []; % size: 1 * total_ground_truth_source_regions + for ii=1:length(num_region_per_source) + mapping = [mapping ii*ones(1,num_region_per_source(ii))]; + end + source_id = mapping(min_ind); + + % calculate metrics + for k=1:max(source_id) + recon = current_regions(source_id==k); + lb = squeeze(gt.selected_region(i,k,:)); + lb(myisnan(lb)) = []; + if ~isempty(recon) + recon_regions{i,k} = recon; + recon_activity{i,k} = all_out{i}(source_id==k,:); + + interc = intersect(recon,lb); + precision(i,k) = length(interc)/length(recon); + recall(i,k) = length(interc)/length(lb); + all_corr(i,k) = corr(mean(all_out{i}(source_id==k,:),1)',squeeze(all_nmm(i,k,:))); + le(i,k) = mean(min(all_dis(recon+1,lb+1),[],2)); + end + end + +end +% save and display results +s(1) = mean(precision(:), 'omitnan');s(2) = mean(recall(:),'omitnan');s(3) = mean(all_corr(:),'omitnan');s(4) = mean(le(:),'omitnan');s +save(['../model_result/' model_id '_the_model/recon' fname '.mat'],'precision','recall','le','all_corr','recon_regions','recon_activity') +%% +function y = myisnan(x) + y = abs(x-15213)<1e-6; +end + + + diff --git a/misc_scripts/forward_prepare_lfd_matrix.m b/misc_scripts/forward_prepare_lfd_matrix.m new file mode 100644 index 0000000..240a2c1 --- /dev/null +++ b/misc_scripts/forward_prepare_lfd_matrix.m @@ -0,0 +1,41 @@ +% Calculate the forward matrix based on the cortical segmentation. +% Assume the headmodel was exported from Brainstorm as bs_headmodel + +assert(size(bs_headmodel.Gain, 2) == length(rm)*3) + +lfd_free = bs_headmodel.Gain; +lfd = lfd_free2fix(lfd_free,bs_headmodel.GridOrient); +fwd = fwd_to_rmfwd(lfd, rm); + +save('../anatomy/leadfield_test.mat','fwd') + +function fwd = lfd_free2fix(lfd, ori) +% Transfer the free direction leadfield to fix direction +% INPUTS: +% - lfd : free direction leadfield, num_electrode * (num_vertice*3) +% - ori : orientation, num_vertice * 3 +% OUTPUTS: +% - fwd : free direction leadfield, num_electrode * num_vertice + +num_vertices = size(lfd,2)/3; +fwd = zeros(size(lfd,1), num_vertices); +for i = 1:num_vertices + fwd(:,i) = lfd(:, (i-1)*3+1:i*3) * ori(i,:)'; +end +end + +function new_fwd = fwd_to_rmfwd(fwd, rm) +% Calculate the sumed forward matrix for each region, assume each dipole in +% the source region has the same activity +% INPUTS: +% - fwd : fix direction leadfield, num_electrode * num_vertice +% - rm : region mapping array, map each vertice to a region, num_vertic*1 +% OUTPUTS: +% - new_fwd : leadfield for each region, num_electrode * num_region + +unique_rm = unique(rm); +new_fwd = zeros(size(fwd,1), length(unique_rm)); +for i=1:length(unique_rm) + new_fwd(:,i) = sum(fwd(:, rm==unique_rm(i)),2); +end +end \ No newline at end of file diff --git a/network.py b/network.py new file mode 100644 index 0000000..34dd01f --- /dev/null +++ b/network.py @@ -0,0 +1,71 @@ +from torch import nn + + +class MLPSpatialFilter(nn.Module): + + def __init__(self, num_sensor, num_hidden, activation): + super(MLPSpatialFilter, self).__init__() + self.fc11 = nn.Linear(num_sensor, num_sensor) + self.fc12 = nn.Linear(num_sensor, num_sensor) + self.fc21 = nn.Linear(num_sensor, num_hidden) + self.fc22 = nn.Linear(num_hidden, num_hidden) + self.fc23 = nn.Linear(num_sensor, num_hidden) + self.value = nn.Linear(num_hidden, num_hidden) + self.activation = nn.__dict__[activation]() + + def forward(self, x): + out = dict() + x = self.activation(self.fc12(self.activation(self.fc11(x))) + x) + x = self.activation(self.fc22(self.activation(self.fc21(x))) + self.fc23(x)) + out['value'] = self.value(x) + out['value_activation'] = self.activation(out['value']) + return out + + +class TemporalFilter(nn.Module): + + def __init__(self, input_size, num_source, num_layer, activation): + super(TemporalFilter, self).__init__() + self.rnns = nn.ModuleList() + self.rnns.append(nn.LSTM(input_size, num_source, batch_first=True, num_layers=num_layer)) + self.num_layer = num_layer + self.input_size = input_size + self.activation = nn.__dict__[activation]() + + def forward(self, x): + out = dict() + # c0/h0 : num_layer, T, num_out + for l in self.rnns: + l.flatten_parameters() + x, _ = l(x) + + out['rnn'] = x # seq_len, batch, num_directions * hidden_size + return out + + +class TemporalInverseNet(nn.Module): + + def __init__(self, num_sensor=64, num_source=994, rnn_layer=1, + spatial_model=MLPSpatialFilter, temporal_model=TemporalFilter, + spatial_output='value_activation', temporal_output='rnn', + spatial_activation='ReLU', temporal_activation='ReLU', temporal_input_size=500): + super(TemporalInverseNet, self).__init__() + self.attribute_list = [num_sensor, num_source, rnn_layer, + spatial_model, temporal_model, spatial_output, temporal_output, + spatial_activation, temporal_activation, temporal_input_size] + self.spatial_output = spatial_output + self.temporal_output = temporal_output + # Spatial filtering + self.spatial = spatial_model(num_sensor, temporal_input_size, spatial_activation) + # Temporal filtering + self.temporal = temporal_model(temporal_input_size, num_source, rnn_layer, temporal_activation) + + def forward(self, x): + out = dict() + out['fc2'] = self.spatial(x)[self.spatial_output] + x = out['fc2'] + out['last'] = self.temporal(x)[self.temporal_output] + return out + + def count_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..42bebf5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +torch >= 1.6.0 +numpy +mne +h5py +tvb \ No newline at end of file diff --git a/source/VEP/data1.mat b/source/VEP/data1.mat new file mode 100644 index 0000000..c162a42 Binary files /dev/null and b/source/VEP/data1.mat differ diff --git a/source/VEP/rnn_test_64_.pth.tar.mat b/source/VEP/rnn_test_64_.pth.tar.mat new file mode 100644 index 0000000..a5404ef Binary files /dev/null and b/source/VEP/rnn_test_64_.pth.tar.mat differ diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..f07c40d --- /dev/null +++ b/utils.py @@ -0,0 +1,101 @@ +import numpy as np +from skimage.filters import threshold_otsu + + +def ispadding(x): + # identify the padding in array + return np.abs(x - 15213) < 1e-6 + + +def get_otsu_regions(out, labels, args_params = None): + """ Identify DeepSIF source region using otsu_threshould, run on CPU + :param out: np.arrry; the output of DeepSIF, batch_size * num_time * num_region + :param labels: np.arrry; group truth source region; batch_size * num_source * max_size; starts from 0 + :param args_params: optional parameters, could be + dis_matrix: np.array; distance between regions; num_region (994) * num_region + :return return_eval: could be + all_regions: DeepSIF predicted regions; (batch_size, ) + all_out: DeepSIF predicted source activity; (batch_size, ) + """ + # when there is no spike, the location error is nan + + batch_size = labels.shape[0] + return_eval = dict() + + return_eval['all_regions'] = np.empty((batch_size,), dtype=object) + return_eval['all_out'] = np.empty((batch_size,), dtype=object) + + for i in range(batch_size): + thre_source = np.abs(out[i]) + thre_source = (thre_source - np.min(thre_source)) / np.max(thre_source) + thresh = threshold_otsu(thre_source, nbins=100) + select_pixel = out[i] > thresh + otsu_region = np.where(np.sum(select_pixel, axis=0) > 7)[0] + return_eval['all_regions'][i] = otsu_region + return_eval['all_out'][i] = out[i, :, otsu_region] + + # Calculate the eval metrics in Python overall condition for all sources + if args_params is not None: + return_eval['precision'] = np.zeros(batch_size) + return_eval['recall'] = np.zeros(batch_size) + return_eval['le'] = np.zeros(batch_size) + for i in range(batch_size): + lb = labels[i][np.logical_not(ispadding(labels[i]))] + recon = return_eval['all_regions'][i] + overlap_region = len(np.intersect1d(lb, recon)) + # number of region based precision and recall + return_eval['precision'][i] = overlap_region/len(recon) + return_eval['recall'][i] = overlap_region / len(lb) + le_each_region = np.min(args_params['dis_matrix'][recon][:, lb], axis = 1) + return_eval['le'][i] = np.mean(le_each_region) + + return return_eval + + +def add_white_noise(sig, snr, args_params=None): + """ + :param sig: np.array; num_electrode * num_time + :param snr: int; signal to noise level in dB + :param args_params: optional parameters, could be + ratio: np.array; ratio between white Gaussian noise and pre-set realistic noise + rndata: np.array; realistic noise data; num_sample * num_electrode * num_time + rnpower: np.array; pre-calculated power for rndata; num_sample * num_electrode + + :return: noise_sig: np.array; num_electrode * num_time + """ + + num_elec, num_time = sig.shape + noise_sig = np.zeros((num_elec, num_time)) + sig_power = np.square(np.linalg.norm(sig, axis=1))/num_time + if args_params is None: + # Only add Gaussian noise + for i in range(num_elec): + noise_power = 10 ** (-(snr / 10)) * sig_power[i] / 2 + noise_std = np.sqrt(noise_power) + noise_sig[i, :] = sig[i, :] + np.random.normal(0, noise_std, (num_time,)) + else: + # Add realistic and Gaussian noise + rnpower = args_params['rnpower']/num_time + rndata = args_params['rndata'] + select_id = np.random.randint(0, rndata.shape[0]) + for i in range(num_elec): + noise_power = 10 ** (-(snr / 10)) * sig_power[i] + rpower = args_params['ratio']*noise_power # realistic noise power + noise_std = np.sqrt(noise_power - rpower) + noise_sig[i, :] = sig[i, :] + np.random.normal(0, noise_std, (num_time,)) + np.sqrt(rpower/rnpower[select_id][i])*rndata[select_id][:, i] + return noise_sig + + +def fwdJ_to_cortexJ(recon, rm): + """ + :param recon: np.array; DeepSIF output, (num_time, num_region) + :param rm: np.array; region mapping for each index, (num_vertices, ) + :return: J: np.array; DeepSIF output for each vertices, (num_time, num_vertices) + """ + num_time, num_region = recon.shape + num_vertices = rm.shape[0] + J = np.zeros((num_time, num_vertices)) + for k in range(num_time): + for i in range(num_region): + J[k, rm==i] = recon[k, i] + return J \ No newline at end of file