Skip to content

Commit

Permalink
test scripts unmodified
Browse files: browse the repository at this point in the history
  • Loading branch information
kevinmicha committed Dec 11, 2023
1 parent 78ea24f commit c3656c0
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 58 deletions.
6 changes: 2 additions & 4 deletions scripts/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
default=1, help='Size of the max pooling operation.')
parser.add_argument('--modes', dest='modes', type=int,
default=30, help='Normal modes into consideration.')
parser.add_argument('--regions', dest='regions', type=str,
default='paired_hl', help='Choose between paired_hl (heavy chain, light chain and their interactions) and heavy (heavy chain only).')
parser.add_argument('--learning_rate', dest='learning_rate', type=float,
default=4e-4, help='Step size at each iteration.')
parser.add_argument('--n_max_epochs', dest='n_max_epochs', type=int,
Expand Down Expand Up @@ -79,9 +77,9 @@ def main(args):
test_losses.extend(test_loss)

## Saving Neural Network checkpoint
path = CHECKPOINTS_DIR + 'model_' + regions + '_epochs_' + str(n_max_epochs) + '_modes_' + str(modes) + '_pool_' + str(pooling_size) + '_filters_' + str(n_filters) + '_size_' + str(filter_size) + '.pt'
path = CHECKPOINTS_DIR + 'model_epochs_' + str(n_max_epochs) + '_modes_' + str(modes) + '_pool_' + str(pooling_size) + '_filters_' + str(n_filters) + '_size_' + str(filter_size) + '.pt'
save_checkpoint(path, model, optimiser, train_losses, test_losses)
np.save(CHECKPOINTS_DIR+'learnt_filter_'+regions+'_epochs_'+str(n_max_epochs)+'_modes_'+str(modes)+'_pool_'+str(pooling_size)+'_filters_'+str(n_filters)+'_size_'+str(filter_size)+'.npy', learnt_filter)
np.save(CHECKPOINTS_DIR+'learnt_filter_epochs_'+str(n_max_epochs)+'_modes_'+str(modes)+'_pool_'+str(pooling_size)+'_filters_'+str(n_filters)+'_size_'+str(filter_size)+'.npy', learnt_filter)

if __name__ == '__main__':
main(arguments)
5 changes: 2 additions & 3 deletions tests/alphafold_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def setUp(self):
self.mode = 'fully-extended' # Choose between 'fully-extended' and 'fully-cropped'
self.pathological = ['5omm', '1mj7', '1qfw', '1qyg', '4ffz', '3ifl', '3lrh', '3pp4', '3ru8', '3t0w', '3t0x', '4fqr', '4gxu', '4jfx', '4k3h', '4jfz', '4jg0', '4jg1', '4jn2', '4o4y', '4qxt', '4r3s', '4w6y', '4w6y', '5ies', '5ivn', '5j57', '5kvd', '5kzp', '5mes', '5nmv', '5sy8', '5t29', '5t5b', '5vag', '3etb', '3gkz', '3uze', '3uzq', '4f9l', '4gqp', '4r2g', '5c6t']
self.stage = 'predicting'
self.regions = 'paired_hl'
self.data_path = 'data/'
self.test_data_path = os.path.join('notebooks/', 'test_data/')
self.test_dccm_map_path = 'dccm_map/'
Expand All @@ -37,11 +36,11 @@ def test_alphafold(self):

for test_pdb, h_offset, l_offset in zip(self.test_pdbs, self.h_offset_list, self.l_offset_list):

preprocessed_data = Preprocessing(data_path=self.data_path, modes=self.modes, pathological=self.pathological, mode=self.mode, stage=self.stage, regions=self.regions, test_data_path=self.test_data_path, test_dccm_map_path=self.test_dccm_map_path, test_residues_path=self.test_residues_path, test_structure_path=self.test_structure_path, test_pdb_id=test_pdb+'_af', alphafold=True, h_offset=h_offset, l_offset=l_offset)
preprocessed_data = Preprocessing(data_path=self.data_path, modes=self.modes, pathological=self.pathological, mode=self.mode, stage=self.stage, test_data_path=self.test_data_path, test_dccm_map_path=self.test_dccm_map_path, test_residues_path=self.test_residues_path, test_structure_path=self.test_structure_path, test_pdb_id=test_pdb+'_af', alphafold=True, h_offset=h_offset, l_offset=l_offset)
self.af_pred_structures.append(preprocessed_data.test_x)
input_shape = preprocessed_data.test_x.shape[-1]

path = 'checkpoints/model_' + self.regions + '_epochs_' + str(self.n_max_epochs) + '_modes_' + str(self.modes) + '_pool_' + str(self.pooling_size) + '_filters_' + str(self.n_filters) + '_size_' + str(self.filter_size) + '.pt'
path = 'checkpoints/model_epochs_' + str(self.n_max_epochs) + '_modes_' + str(self.modes) + '_pool_' + str(self.pooling_size) + '_filters_' + str(self.n_filters) + '_size_' + str(self.filter_size) + '.pt'
model = load_checkpoint(path, input_shape)[0]
model.eval()

Expand Down
2 changes: 1 addition & 1 deletion tests/biology_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ def setUp(self):
self.path = TEST_PATH

def test_biology(self):
get_cdr_lengths(['1fl6', '4fab', '5d70', '1kxt', '1g6v', '2p44', '2jb6', '6b9j'])
get_cdr_lengths(['1fl6', '4fab', '5d70', '1kxt', '1g6v', '2p44', '2jb6', '6b9j'], 'data/')
get_types_of_residues(['1t66', '1kel', '6mlb', 'abcd'])
3 changes: 1 addition & 2 deletions tests/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def setUp(self):
self.pooling_size = 1
self.n_max_epochs = 552
self.data_path = 'data/'
self.regions = 'paired_hl'
self.df = 'summary.tsv'
self.test_data_path = os.path.join('notebooks/', 'test_data/')
self.test_dccm_map_path = 'dccm_map/'
Expand All @@ -30,7 +29,7 @@ def setUp(self):

def test_evaluation(self):

preprocessed_data = Preprocessing(data_path=self.data_path, modes=self.modes, pathological=self.pathological, regions=self.regions, stage=self.stage, test_data_path=self.test_data_path, test_dccm_map_path=self.test_dccm_map_path, test_residues_path=self.test_residues_path, test_structure_path=self.test_structure_path)
preprocessed_data = Preprocessing(data_path=self.data_path, modes=self.modes, pathological=self.pathological, stage=self.stage, test_data_path=self.test_data_path, test_dccm_map_path=self.test_dccm_map_path, test_residues_path=self.test_residues_path, test_structure_path=self.test_structure_path)
input_shape = preprocessed_data.test_x.shape[-1]

# Testing cases of load checkpoint
Expand Down
38 changes: 16 additions & 22 deletions tests/explaining_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# ANTIPASTI
from antipasti.preprocessing.preprocessing import Preprocessing
from antipasti.utils.explaining_utils import compute_change_in_kd, compute_umap, get_epsilon, get_maps_of_interest, get_test_contribution, map_residues_to_regions, plot_map_with_regions
from antipasti.utils.explaining_utils import compute_umap, get_maps_of_interest, get_test_contribution, plot_map_with_regions
from antipasti.utils.torch_utils import load_checkpoint
from tests import TEST_PATH

Expand All @@ -21,7 +21,6 @@ def setUp(self):
self.mode = 'fully-extended' # Choose between 'fully-extended' and 'fully-cropped'
self.pathological = ['5omm', '1mj7', '1qfw', '1qyg', '4ffz', '3ifl', '3lrh', '3pp4', '3ru8', '3t0w', '3t0x', '4fqr', '4gxu', '4jfx', '4k3h', '4jfz', '4jg0', '4jg1', '4jn2', '4o4y', '4qxt', '4r3s', '4w6y', '4w6y', '5ies', '5ivn', '5j57', '5kvd', '5kzp', '5mes', '5nmv', '5sy8', '5t29', '5t5b', '5vag', '3etb', '3gkz', '3uze', '3uzq', '4f9l', '4gqp', '4r2g', '5c6t']
self.stage = 'predicting'
self.regions = 'paired_hl'
self.data_path = 'data/'
self.test_data_path = os.path.join('notebooks/', 'test_data/')
self.test_dccm_map_path = 'dccm_map/'
Expand All @@ -36,34 +35,29 @@ def test_explaining(self):
'Other': 2}

# Pre-processing
preprocessed_data = Preprocessing(data_path=self.data_path, modes=self.modes, pathological=self.pathological, regions=self.regions, mode=self.mode, stage=self.stage, test_data_path=self.test_data_path, test_dccm_map_path=self.test_dccm_map_path, test_residues_path=self.test_residues_path, test_structure_path=self.test_structure_path)
preprocessed_data = Preprocessing(data_path=self.data_path, modes=self.modes, pathological=self.pathological, mode=self.mode, stage=self.stage, test_data_path=self.test_data_path, test_dccm_map_path=self.test_dccm_map_path, test_residues_path=self.test_residues_path, test_structure_path=self.test_structure_path)
input_shape = preprocessed_data.test_x.shape[-1]

# Loading the actual checkpoint and learnt filters
path = 'checkpoints/model_' + self.regions + '_epochs_' + str(self.n_max_epochs) + '_modes_' + str(self.modes) + '_pool_' + str(self.pooling_size) + '_filters_' + str(self.n_filters) + '_size_' + str(self.filter_size) + '.pt'
path = 'checkpoints/model_epochs_' + str(self.n_max_epochs) + '_modes_' + str(self.modes) + '_pool_' + str(self.pooling_size) + '_filters_' + str(self.n_filters) + '_size_' + str(self.filter_size) + '.pt'
model = load_checkpoint(path, input_shape)[0]
learnt_filter = np.load('checkpoints/learnt_filter_'+self.regions+'_epochs_'+str(self.n_max_epochs)+'_modes_'+str(self.modes)+'_pool_'+str(self.pooling_size)+'_filters_'+str(self.n_filters)+'_size_'+str(self.filter_size)+'.npy')
learnt_filter = np.load('checkpoints/learnt_filter_epochs_'+str(self.n_max_epochs)+'_modes_'+str(self.modes)+'_pool_'+str(self.pooling_size)+'_filters_'+str(self.n_filters)+'_size_'+str(self.filter_size)+'.npy')
model.eval()

mean_learnt, mean_image, mean_diff_image = get_maps_of_interest(preprocessed_data, learnt_filter)
plot_map_with_regions(preprocessed_data, mean_learnt, 'Average of learnt representations', True)

contribution = get_test_contribution(preprocessed_data, model)
epsilon = get_epsilon(preprocessed_data, model)
epsilon = get_epsilon(preprocessed_data, model, mode='extreme')
coord, maps, labels = map_residues_to_regions(preprocessed_data, epsilon)
get_test_contribution(preprocessed_data, model)

# Expressing weights as vector
weights_h = [0.1, 0.1, 0, 0.1, 0, 0.1, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0.1]
weights_l = [0.1, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0.1]
weights = np.array(weights_h + weights_l)

compute_change_in_kd(preprocessed_data, model, weights, coord, maps)

compute_umap(preprocessed_data, model, scheme='heavy_species', regions='paired_hl', categorical=True, include_ellipses=True, numerical_values=None, external_cdict=None, interactive=True)
compute_umap(preprocessed_data, model, scheme='heavy_subclass', regions='heavy', categorical=True, include_ellipses=False, numerical_values=None, external_cdict=None, interactive=True)
compute_umap(preprocessed_data, model, scheme='antigen_species', regions='paired_hl', categorical=True, include_ellipses=False, numerical_values=None, external_cdict=cdict, interactive=True)
compute_umap(preprocessed_data, model, scheme='Random sequence', regions='paired_hl', categorical=False, include_ellipses=False, numerical_values=list(np.linspace(0, 1, num=preprocessed_data.train_x.shape[0])), external_cdict=None, interactive=True)


random_sequence = list(np.linspace(0, 1, num=preprocessed_data.train_x.shape[0]))
random_sequence_delete = random_sequence.copy()
random_sequence_delete[0] = 'unknown'
random_sequence = [str(random_sequence[0])] + random_sequence[1:]

compute_umap(preprocessed_data, model, scheme='heavy_species', categorical=True, include_ellipses=True, numerical_values=None, external_cdict=None, interactive=True)
compute_umap(preprocessed_data, model, scheme='light_ctype', categorical=True, include_ellipses=True, numerical_values=None, external_cdict=None, interactive=True)
compute_umap(preprocessed_data, model, scheme='light_subclass', categorical=True, include_ellipses=False, numerical_values=None, external_cdict=None, interactive=True)
compute_umap(preprocessed_data, model, scheme='antigen_type', categorical=True, include_ellipses=False, numerical_values=None, external_cdict=None, interactive=True)
compute_umap(preprocessed_data, model, scheme='antigen_species', categorical=True, include_ellipses=False, numerical_values=None, external_cdict=cdict, interactive=True)
compute_umap(preprocessed_data, model, scheme='Random sequence', categorical=False, include_ellipses=False, numerical_values=random_sequence, external_cdict=None, interactive=True)
compute_umap(preprocessed_data, model, scheme='Random sequence', categorical=False, include_ellipses=False, numerical_values=random_sequence_delete, external_cdict=None, interactive=True)
27 changes: 1 addition & 26 deletions tests/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,4 @@ def test_training_paired_hl(self):

# Saving the checkpoint
path = 'checkpoints/model_unit_test.pt'
save_checkpoint(path, model, optimiser, train_losses, test_losses)

def test_training_heavy(self):
preprocessed_data = Preprocessing(data_path=self.data_path, structures_path=self.structures_path, scripts_path=self.scripts_path, df=self.df, regions='heavy', pathological=self.pathological)
train_x, test_x, train_y, test_y, _, _ = create_test_set(preprocessed_data.train_x, preprocessed_data.train_y, test_size=0.5)

n_filters = 3
filter_size = 5
pooling_size = 2
learning_rate = 4e-4
n_max_epochs = 10
max_corr = 0.87
batch_size = 1
input_shape = preprocessed_data.train_x.shape[-1]

model = ANTIPASTI(n_filters=n_filters, filter_size=filter_size, pooling_size=pooling_size, input_shape=input_shape)
criterion = MSELoss()
optimiser = AdaBelief(model.parameters(), lr=learning_rate, eps=1e-8, print_change_log=False)

train_losses = []
test_losses = []
train_loss, test_loss, _, _, _ = training_routine(model, criterion, optimiser, train_x, test_x, train_y, test_y, n_max_epochs=n_max_epochs, max_corr=max_corr, batch_size=batch_size, verbose=False)

# Saving the losses
train_losses.extend(train_loss)
test_losses.extend(test_loss)
save_checkpoint(path, model, optimiser, train_losses, test_losses)

0 comments on commit c3656c0

Please sign in to comment.