From e4e5062170b73a8c1c74f49e3b9695c969c55b77 Mon Sep 17 00:00:00 2001
From: Dima Molodenskiy
Date: Tue, 3 Sep 2024 13:08:52 +0200
Subject: [PATCH] Compress features if --compress-features. Added corresponding tests

---
 .../scripts/create_individual_features.py |  31 ++-
 alphapulldown/utils/save_meta_data.py     |   5 +-
 test/test_features_with_templates.py      | 196 ++++++++++--------
 3 files changed, 130 insertions(+), 102 deletions(-)

diff --git a/alphapulldown/scripts/create_individual_features.py b/alphapulldown/scripts/create_individual_features.py
index 940106b4..9a431bc5 100644
--- a/alphapulldown/scripts/create_individual_features.py
+++ b/alphapulldown/scripts/create_individual_features.py
@@ -2,7 +2,8 @@
 # coding: utf-8
 # Create features for AlphaFold from fasta file(s) or a csv file with descriptions for multimeric templates
 #
 #
-
+import json
+import lzma
 import os
 import pickle
 import sys
@@ -40,6 +41,9 @@
 flags.DEFINE_boolean("use_hhsearch", False,
                      "Use hhsearch instead of hmmsearch when looking for structure template. Default is False")
+flags.DEFINE_boolean("compress_features", False,
+                     "Compress features.pkl and meta.json files using lzma algorithm. Default is False")
+
 # Flags related to TrueMultimer
 flags.DEFINE_string("path_to_mmt", None,
                     "Path to directory with multimeric template mmCIF files")
@@ -229,10 +233,18 @@ def create_and_save_monomer_objects(monomer, pipeline):
         return
 
     # Save metadata
-    metadata_output_path = os.path.join(FLAGS.output_dir,
-                                        f"{monomer.description}_feature_metadata_{datetime.date(datetime.now())}.json")
-    with save_meta_data.output_meta_file(metadata_output_path) as meta_data_outfile:
-        save_meta_data.save_meta_data(flags_dict, meta_data_outfile)
+    meta_dict = save_meta_data.get_meta_dict(flags_dict)
+    metadata_output_path = os.path.join(
+        FLAGS.output_dir,
+        f"{monomer.description}_feature_metadata_{datetime.now().date()}.json"
+    )
+
+    if FLAGS.compress_features:
+        with lzma.open(metadata_output_path + '.xz', "wt") as meta_data_outfile:
+            json.dump(meta_dict, meta_data_outfile)
+    else:
+        with open(metadata_output_path, "w") as meta_data_outfile:
+            json.dump(meta_dict, meta_data_outfile)
 
     # Create features
     if FLAGS.use_mmseqs2:
@@ -250,8 +262,13 @@ def create_and_save_monomer_objects(monomer, pipeline):
         )
 
     # Save the processed monomer object
-    with open(pickle_path, "wb") as pickle_file:
-        pickle.dump(monomer, pickle_file)
+    if FLAGS.compress_features:
+        pickle_path = pickle_path + ".xz"
+        with lzma.open(pickle_path, "wb") as pickle_file:
+            pickle.dump(monomer, pickle_file)
+    else:
+        with open(pickle_path, "wb") as pickle_file:
+            pickle.dump(monomer, pickle_file)
 
     # Optional: Clear monomer from memory if necessary
     del monomer
diff --git a/alphapulldown/utils/save_meta_data.py b/alphapulldown/utils/save_meta_data.py
index 50d78533..d7ee6274 100644
--- a/alphapulldown/utils/save_meta_data.py
+++ b/alphapulldown/utils/save_meta_data.py
@@ -111,7 +111,7 @@ def get_metadata_for_database(k, v):
         return {}
 
 
-def save_meta_data(flag_dict, outfile):
+def get_meta_dict(flag_dict):
     """Save metadata in JSON format."""
     metadata = {
         "databases": {},
@@ -139,8 +139,7 @@ def save_meta_data(flag_dict, outfile):
             "location_url": url}
         })
 
-    with open(outfile, "w") as f:
-        json.dump(metadata, f, indent=2)
+    return metadata
 
 
 def get_last_modified_date(path):
diff --git a/test/test_features_with_templates.py b/test/test_features_with_templates.py
index 053da6cd..e0a66c8b 100644
--- a/test/test_features_with_templates.py
+++ b/test/test_features_with_templates.py
@@ -1,21 +1,23 @@
+import os
 import shutil
 import pickle
 import tempfile
 import subprocess
 from pathlib import Path
+import lzma
+import json
 
 import numpy as np
-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 
 from alphapulldown.utils.remove_clashes_low_plddt import extract_seqs
 
 
-class TestCreateIndividualFeaturesWithTemplates(absltest.TestCase):
+class TestCreateIndividualFeaturesWithTemplates(parameterized.TestCase):
     def setUp(self):
         super().setUp()
         self.temp_dir = tempfile.TemporaryDirectory()  # Create a temporary directory
         self.TEST_DATA_DIR = Path(self.temp_dir.name)  # Use the temporary directory as the test data directory
-        #self.TEST_DATA_DIR = Path(__file__).parent / "test_data" # DELETEME
         # Copy test data files to the temporary directory
         original_test_data_dir = Path(__file__).parent / "test_data"
         shutil.copytree(original_test_data_dir, self.TEST_DATA_DIR, dirs_exist_ok=True)
@@ -26,74 +28,47 @@ def setUp(self):
     def tearDown(self):
         self.temp_dir.cleanup()  # Clean up the temporary directory
 
-    def run_features_generation(self, file_name, chain_id, file_extension, use_mmseqs2):
-        # Ensure directories exist
-        (self.TEST_DATA_DIR / 'features').mkdir(parents=True, exist_ok=True)
-        (self.TEST_DATA_DIR / 'templates').mkdir(parents=True, exist_ok=True)
-        # Remove existing files (should be done by tearDown, but just in case)
-        pkl_path = self.TEST_DATA_DIR / 'features' / f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}.pkl'
-        if use_mmseqs2:
-            sto_or_a3m_path = (
-                self.TEST_DATA_DIR / 'features' /
-                f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}' /
-                f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}.a3m'
-            )
-        else:
-            sto_or_a3m_path = (
-                self.TEST_DATA_DIR / 'features' /
-                f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}' /
-                'pdb_hits.sto'
-            )
-        template_path = self.TEST_DATA_DIR / 'templates' / f'{file_name}.{file_extension}'
-        if pkl_path.exists():
-            pkl_path.unlink()
-        if sto_or_a3m_path.exists():
-            sto_or_a3m_path.unlink()
-
-        # Generate description.csv
-        with open(f"{self.TEST_DATA_DIR}/description.csv", 'w') as desc_file:
-            desc_file.write(f">{file_name}_{chain_id}, {file_name}.{file_extension}, {chain_id}\n")
+    def create_mock_file(self, file_path):
+        Path(file_path).parent.mkdir(parents=True, exist_ok=True)
+        Path(file_path).touch(exist_ok=True)
 
-        assert Path(f"{self.TEST_DATA_DIR}/fastas/{file_name}_{chain_id}.fasta").exists()
-
-        def create_mock_file(file_path):
-            Path(file_path).parent.mkdir(parents=True, exist_ok=True)
-            Path(file_path).touch(exist_ok=True)
-
-        # Common root directory
+    def mock_databases(self):
         root_dir = self.TEST_DATA_DIR
-        # Mock databases
-        create_mock_file(root_dir / 'uniref90/uniref90.fasta')
-        create_mock_file(root_dir / 'mgnify/mgy_clusters_2022_05.fa')
-        create_mock_file(root_dir / 'uniprot/uniprot.fasta')
-        create_mock_file(root_dir / 'bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_hhm.ffindex')
-        create_mock_file(root_dir / 'uniref30/UniRef30_2021_03_hhm.ffindex')
-        create_mock_file(root_dir / 'uniref30/UniRef30_2023_02_hhm.ffindex')
-        create_mock_file(root_dir / 'pdb70/pdb70_hhm.ffindex')
+        self.create_mock_file(root_dir / 'uniref90/uniref90.fasta')
+        self.create_mock_file(root_dir / 'mgnify/mgy_clusters_2022_05.fa')
+        self.create_mock_file(root_dir / 'uniprot/uniprot.fasta')
+        self.create_mock_file(root_dir / 'bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_hhm.ffindex')
+        self.create_mock_file(root_dir / 'uniref30/UniRef30_2021_03_hhm.ffindex')
+        self.create_mock_file(root_dir / 'uniref30/UniRef30_2023_02_hhm.ffindex')
+        self.create_mock_file(root_dir / 'pdb70/pdb70_hhm.ffindex')
 
-        # Mock hhblits files
         hhblits_root = root_dir / 'bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt'
-        hhblits_files = ['_a3m.ffdata', '_a3m.ffindex', '_cs219.ffdata', '_cs219.ffindex', '_hhmm.ffdata',
-                         '_hhmm.ffindex']
+        hhblits_files = ['_a3m.ffdata', '_a3m.ffindex', '_cs219.ffdata', '_cs219.ffindex', '_hhmm.ffdata', '_hhmm.ffindex']
         for file in hhblits_files:
-            create_mock_file(hhblits_root / file)
+            self.create_mock_file(hhblits_root / file)
 
-        # Mock uniclust30 files
         uniclust_db_root = root_dir / 'uniclust30/uniclust30_2018_08/uniclust30_2018_08'
         uniclust_db_files = ['_a3m_db', '_a3m.ffdata', '_a3m.ffindex', '.cs219', '_cs219.ffdata', '_cs219.ffindex',
                              '_hhm_db', '_hhm.ffdata', '_hhm.ffindex']
         for suffix in uniclust_db_files:
-            create_mock_file(f"{uniclust_db_root}{suffix}")
+            self.create_mock_file(f"{uniclust_db_root}{suffix}")
 
-        # Mock uniref30 files - Adjusted for the correct naming convention
-        #uniref_db_root = root_dir / 'uniref30/UniRef30_2021_03'
         uniref_db_root = root_dir / 'uniref30/UniRef30_2023_02'
         uniref_db_files = ['_a3m.ffdata', '_a3m.ffindex', '_hmm.ffdata', '_hmm.ffindex', '_cs.ffdata', '_cs.ffindex']
         for suffix in uniref_db_files:
-            create_mock_file(f"{uniref_db_root}{suffix}")
+            self.create_mock_file(f"{uniref_db_root}{suffix}")
+
+    def run_features_generation(self, file_name, chain_id, file_extension, use_mmseqs2, compress_features=False):
+        (self.TEST_DATA_DIR / 'features').mkdir(parents=True, exist_ok=True)
+        (self.TEST_DATA_DIR / 'templates').mkdir(parents=True, exist_ok=True)
+        self.mock_databases()
+
+        with open(f"{self.TEST_DATA_DIR}/description.csv", 'w') as desc_file:
+            desc_file.write(f">{file_name}_{chain_id}, {file_name}.{file_extension}, {chain_id}\n")
+
+        assert Path(f"{self.TEST_DATA_DIR}/fastas/{file_name}_{chain_id}.fasta").exists()
 
-        # Prepare the command and arguments
         cmd = [
             'create_individual_features.py',
             '--use_precomputed_msas', 'True',
@@ -110,26 +85,32 @@ def create_mock_file(file_path):
             '--output_dir', f"{self.TEST_DATA_DIR}/features",
         ]
         if use_mmseqs2:
-            cmd.extend(['--use_mmseqs2', 'True',])
-        print(" ".join(cmd))
-        # Check the output
+            cmd.extend(['--use_mmseqs2', 'True'])
+        if compress_features:
+            cmd.extend(['--compress_features', 'True'])
         subprocess.run(cmd, check=True)
-        features_dir = self.TEST_DATA_DIR / 'features'
-
-        # List all files in the directory
-        for file in features_dir.iterdir():
-            if file.is_file():
-                print(file)
-        print("pkl path")
-        print(pkl_path)
-        assert pkl_path.exists()
-        assert sto_or_a3m_path.exists()
-
-        with open(pkl_path, 'rb') as f:
-            feats = pickle.load(f).feature_dict
-        temp_sequence = feats['template_sequence'][0].decode('utf-8')
-        target_sequence = feats['sequence'][0].decode('utf-8')
-        atom_coords = feats['template_all_atom_positions'][0]
+
+    def validate_generated_features(self, pkl_path, json_path, file_name, file_extension, chain_id, compress_features):
+        # Validate that the expected output files exist
+        self.assertTrue(json_path.exists(), f"Metadata JSON file was not created: {json_path}")
+        self.assertTrue(pkl_path.exists(), f"Pickle file was not created: {pkl_path}")
+
+        # Validate the contents of the PKL file
+        if compress_features:
+            with lzma.open(pkl_path, 'rb') as f:
+                monomeric_object = pickle.load(f)
+        else:
+            with open(pkl_path, 'rb') as f:
+                monomeric_object = pickle.load(f)
+
+        self.assertTrue(hasattr(monomeric_object, 'feature_dict'), "Loaded object does not have 'feature_dict' attribute.")
+        features = monomeric_object.feature_dict
+
+        # Validate that the expected sequences and atom coordinates are present and valid
+        temp_sequence = features['template_sequence'][0].decode('utf-8')
+        target_sequence = features['sequence'][0].decode('utf-8')
+        atom_coords = features['template_all_atom_positions'][0]
+        template_path = self.TEST_DATA_DIR / 'templates' / f'{file_name}.{file_extension}'
         # Check that template sequence is not empty
         assert len(temp_sequence) > 0
         # Check that the atom coordinates are not all 0
@@ -168,29 +149,60 @@ def create_mock_file(file_path):
         #print(residue_has_nonzero_coords)
         #print(len(residue_has_nonzero_coords))
 
-    def test_1a_run_features_generation(self):
-        self.run_features_generation('3L4Q', 'A', 'cif', False)
-
-    def test_2c_run_features_generation(self):
-        self.run_features_generation('3L4Q', 'C', 'pdb', False)
-
-    def test_3b_bizarre_filename(self):
-        self.run_features_generation('RANdom_name1_.7-1_0', 'B', 'pdb', False)
-
-    def test_4c_bizarre_filename(self):
-        self.run_features_generation('RANdom_name1_.7-1_0', 'C', 'pdb', False)
-
-    def test_5b_gappy_pdb(self):
-        self.run_features_generation('GAPPY_PDB', 'B', 'pdb', False)
+    @parameterized.parameters(
+        {'compress_features': True, 'file_name': '3L4Q', 'chain_id': 'A', 'file_extension': 'cif'},
+        {'compress_features': False, 'file_name': '3L4Q', 'chain_id': 'A', 'file_extension': 'cif'},
+    )
+    def test_compress_features_flag(self, compress_features, file_name, chain_id, file_extension):
+        self.run_features_generation(file_name, chain_id, file_extension, use_mmseqs2=False, compress_features=compress_features)
+
+        json_pattern = f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}_feature_metadata_*.json'
+        if compress_features:
+            json_pattern += '.xz'
+        metadata_files = list((self.TEST_DATA_DIR / 'features').glob(json_pattern))
+        self.assertTrue(len(metadata_files) > 0, "Metadata JSON file was not created.")
+        json_path = metadata_files[0]
+
+        pkl_filename = f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}.pkl'
+        if compress_features:
+            pkl_filename += '.xz'
+        pkl_path = self.TEST_DATA_DIR / 'features' / pkl_filename
+
+        self.validate_generated_features(pkl_path, json_path, file_name, file_extension, chain_id, compress_features)
+
+        # Clean up
+        pkl_path.unlink(missing_ok=True)
+        json_path.unlink(missing_ok=True)
+
+    @parameterized.parameters(
+        {'file_name': '3L4Q', 'chain_id': 'A', 'file_extension': 'cif', 'use_mmseqs2': False},
+        {'file_name': '3L4Q', 'chain_id': 'C', 'file_extension': 'pdb', 'use_mmseqs2': False},
+        {'file_name': 'RANdom_name1_.7-1_0', 'chain_id': 'B', 'file_extension': 'pdb', 'use_mmseqs2': False},
+        {'file_name': 'RANdom_name1_.7-1_0', 'chain_id': 'C', 'file_extension': 'pdb', 'use_mmseqs2': False},
+        {'file_name': 'GAPPY_PDB', 'chain_id': 'B', 'file_extension': 'pdb', 'use_mmseqs2': False},
+        {'file_name': 'hetatoms', 'chain_id': 'A', 'file_extension': 'pdb', 'use_mmseqs2': False},
+    )
+    def test_run_features_generation(self, file_name, chain_id, file_extension, use_mmseqs2):
+        self.run_features_generation(file_name, chain_id, file_extension, use_mmseqs2)
+
+        # Determine the output paths for validation
+        pkl_filename = f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}.pkl'
+        pkl_path = self.TEST_DATA_DIR / 'features' / pkl_filename
+        json_pattern = f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}_feature_metadata_*.json'
+        metadata_files = list((self.TEST_DATA_DIR / 'features').glob(json_pattern))
+        self.assertTrue(len(metadata_files) > 0, "Metadata JSON file was not created.")
+        json_path = metadata_files[0]
+
+        self.validate_generated_features(pkl_path, json_path, file_name, file_extension, chain_id, compress_features=False)
+
+        # Clean up
+        pkl_path.unlink(missing_ok=True)
+        json_path.unlink(missing_ok=True)
 
     @absltest.skip("use_mmseqs2 must not be set when running with --path_to_mmts")
     def test_6a_mmseqs2(self):
         self.run_features_generation('3L4Q', 'A', 'cif', True)
 
-    def test_7a_hetatoms(self):
-        self.run_features_generation('hetatoms', 'A', 'pdb', False)
-
-
 if __name__ == '__main__':
     absltest.main()
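
Usage note for reviewers: with --compress_features=True the run writes <name>.pkl.xz and
<name>_feature_metadata_<date>.json.xz instead of the plain .pkl/.json files. Below is a
minimal read-back sketch, assuming the layout exercised by the tests above (the "features"
output directory and the 3L4Q_A.3L4Q.cif.A name are taken from the test parameters; the
metadata date suffix is resolved with a glob because it depends on the run date):

    import glob
    import json
    import lzma
    import pickle

    features_dir = "features"  # whatever --output_dir was used for the run

    # The compressed monomer pickle is opened with lzma in binary mode and unpickled;
    # the tests above check that the loaded object exposes .feature_dict.
    with lzma.open(f"{features_dir}/3L4Q_A.3L4Q.cif.A.pkl.xz", "rb") as f:
        monomer = pickle.load(f)
    print(monomer.feature_dict["sequence"][0].decode("utf-8"))

    # The compressed metadata is plain JSON inside an .xz container, so lzma.open in
    # text mode plus json.load is enough.
    meta_path = glob.glob(f"{features_dir}/3L4Q_A.3L4Q.cif.A_feature_metadata_*.json.xz")[0]
    with lzma.open(meta_path, "rt") as f:
        metadata = json.load(f)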