Commit
Compress features if --compress-features. Added corresponding tests
DimaMolod committed Sep 3, 2024
1 parent 7842293 commit e4e5062
Showing 3 changed files with 130 additions and 102 deletions.
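
For orientation, a minimal invocation sketch with the new flag, mirroring how the tests below call the script via subprocess. Only --use_precomputed_msas, --output_dir, and the new --compress_features appear in this commit; the remaining flags and all paths are illustrative assumptions, not taken from the diff.

import subprocess

# Sketch only: paths are placeholders; flags marked "assumed" follow the usual
# AlphaFold-style CLI and are not part of this commit.
cmd = [
    "create_individual_features.py",
    "--fasta_paths", "example.fasta",               # assumed flag
    "--data_dir", "/path/to/alphafold_databases",   # assumed flag
    "--output_dir", "features/",
    "--use_precomputed_msas", "True",
    "--compress_features", "True",                  # new flag added by this commit
]
subprocess.run(cmd, check=True)
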
31 changes: 24 additions & 7 deletions alphapulldown/scripts/create_individual_features.py
@@ -2,7 +2,8 @@
# coding: utf-8
# Create features for AlphaFold from fasta file(s) or a csv file with descriptions for multimeric templates
# #

import json
import lzma
import os
import pickle
import sys
@@ -40,6 +41,9 @@
flags.DEFINE_boolean("use_hhsearch", False,
"Use hhsearch instead of hmmsearch when looking for structure template. Default is False")

flags.DEFINE_boolean("compress_features", False,
"Compress features.pkl and meta.json files using lzma algorithm. Default is False")

# Flags related to TrueMultimer
flags.DEFINE_string("path_to_mmt", None,
"Path to directory with multimeric template mmCIF files")
@@ -229,10 +233,18 @@ def create_and_save_monomer_objects(monomer, pipeline):
return

# Save metadata
metadata_output_path = os.path.join(FLAGS.output_dir,
f"{monomer.description}_feature_metadata_{datetime.date(datetime.now())}.json")
with save_meta_data.output_meta_file(metadata_output_path) as meta_data_outfile:
save_meta_data.save_meta_data(flags_dict, meta_data_outfile)
meta_dict = save_meta_data.get_meta_dict(flags_dict)
metadata_output_path = os.path.join(
FLAGS.output_dir,
f"{monomer.description}_feature_metadata_{datetime.now().date()}.json"
)

if FLAGS.compress_features:
with lzma.open(metadata_output_path + '.xz', "wt") as meta_data_outfile:
json.dump(meta_dict, meta_data_outfile)
else:
with open(metadata_output_path, "w") as meta_data_outfile:
json.dump(meta_dict, meta_data_outfile)

# Create features
if FLAGS.use_mmseqs2:
@@ -250,8 +262,13 @@ def create_and_save_monomer_objects(monomer, pipeline):
)

# Save the processed monomer object
with open(pickle_path, "wb") as pickle_file:
pickle.dump(monomer, pickle_file)
if FLAGS.compress_features:
pickle_path = pickle_path + ".xz"
with lzma.open(pickle_path, "wb") as pickle_file:
pickle.dump(monomer, pickle_file)
else:
with open(pickle_path, "wb") as pickle_file:
pickle.dump(monomer, pickle_file)

# Optional: Clear monomer from memory if necessary
del monomer
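
With --compress_features enabled, the script now writes the pickled monomer with a .pkl.xz suffix and the metadata as <description>_feature_metadata_<date>.json.xz instead of the plain files. A minimal read-back sketch; the file names below are placeholders that only follow this pattern and are not taken from the diff:

import json
import lzma
import pickle

# Placeholder paths following the naming pattern used by the script.
features_path = "features/EXAMPLE.pkl.xz"
metadata_path = "features/EXAMPLE_feature_metadata_2024-09-03.json.xz"

with lzma.open(features_path, "rb") as fh:
    monomer = pickle.load(fh)    # monomeric object exposing .feature_dict

with lzma.open(metadata_path, "rt") as fh:
    metadata = json.load(fh)     # plain dict, as produced by get_meta_dict

The updated test below applies the same compressed-versus-plain branching when it loads the pickle back in validate_generated_features.
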
5 changes: 2 additions & 3 deletions alphapulldown/utils/save_meta_data.py
@@ -111,7 +111,7 @@ def get_metadata_for_database(k, v):
return {}


def save_meta_data(flag_dict, outfile):
def get_meta_dict(flag_dict):
"""Save metadata in JSON format."""
metadata = {
"databases": {},
@@ -139,8 +139,7 @@ def save_meta_data(flag_dict, outfile):
"location_url": url}
})

with open(outfile, "w") as f:
json.dump(metadata, f, indent=2)
return metadata


def get_last_modified_date(path):
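
The change above replaces save_meta_data(flag_dict, outfile), which wrote the JSON itself, with get_meta_dict(flag_dict), which only builds the metadata dictionary and leaves serialization to the caller; that is what lets create_individual_features.py choose between plain and lzma-compressed output. A hedged sketch of the caller-side pattern follows; the write_metadata helper is hypothetical and not part of the repository:

import json
import lzma

def write_metadata(meta_dict, path, compress=False):
    # Mirrors the branching added to create_individual_features.py:
    # lzma-compressed text stream when requested, plain JSON otherwise.
    if compress:
        with lzma.open(path + ".xz", "wt") as fh:
            json.dump(meta_dict, fh)
    else:
        with open(path, "w") as fh:
            json.dump(meta_dict, fh)

# Stand-in dict; in the script it comes from save_meta_data.get_meta_dict(flags_dict).
write_metadata({"databases": {}}, "example_feature_metadata.json", compress=True)
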
196 changes: 104 additions & 92 deletions test/test_features_with_templates.py
@@ -1,21 +1,23 @@
import os
import shutil
import pickle
import tempfile
import subprocess
from pathlib import Path
import lzma
import json

import numpy as np
from absl.testing import absltest
from absl.testing import absltest, parameterized

from alphapulldown.utils.remove_clashes_low_plddt import extract_seqs

class TestCreateIndividualFeaturesWithTemplates(absltest.TestCase):
class TestCreateIndividualFeaturesWithTemplates(parameterized.TestCase):

def setUp(self):
super().setUp()
self.temp_dir = tempfile.TemporaryDirectory() # Create a temporary directory
self.TEST_DATA_DIR = Path(self.temp_dir.name) # Use the temporary directory as the test data directory
#self.TEST_DATA_DIR = Path(__file__).parent / "test_data" # DELETEME
# Copy test data files to the temporary directory
original_test_data_dir = Path(__file__).parent / "test_data"
shutil.copytree(original_test_data_dir, self.TEST_DATA_DIR, dirs_exist_ok=True)
@@ -26,74 +28,47 @@ def setUp(self):
def tearDown(self):
self.temp_dir.cleanup() # Clean up the temporary directory

def run_features_generation(self, file_name, chain_id, file_extension, use_mmseqs2):
# Ensure directories exist
(self.TEST_DATA_DIR / 'features').mkdir(parents=True, exist_ok=True)
(self.TEST_DATA_DIR / 'templates').mkdir(parents=True, exist_ok=True)
# Remove existing files (should be done by tearDown, but just in case)
pkl_path = self.TEST_DATA_DIR / 'features' / f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}.pkl'
if use_mmseqs2:
sto_or_a3m_path = (
self.TEST_DATA_DIR / 'features' /
f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}' /
f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}.a3m'
)
else:
sto_or_a3m_path = (
self.TEST_DATA_DIR / 'features' /
f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}' /
'pdb_hits.sto'
)
template_path = self.TEST_DATA_DIR / 'templates' / f'{file_name}.{file_extension}'
if pkl_path.exists():
pkl_path.unlink()
if sto_or_a3m_path.exists():
sto_or_a3m_path.unlink()

# Generate description.csv
with open(f"{self.TEST_DATA_DIR}/description.csv", 'w') as desc_file:
desc_file.write(f">{file_name}_{chain_id}, {file_name}.{file_extension}, {chain_id}\n")
def create_mock_file(self, file_path):
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
Path(file_path).touch(exist_ok=True)

assert Path(f"{self.TEST_DATA_DIR}/fastas/{file_name}_{chain_id}.fasta").exists()

def create_mock_file(file_path):
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
Path(file_path).touch(exist_ok=True)

# Common root directory
def mock_databases(self):
root_dir = self.TEST_DATA_DIR

# Mock databases
create_mock_file(root_dir / 'uniref90/uniref90.fasta')
create_mock_file(root_dir / 'mgnify/mgy_clusters_2022_05.fa')
create_mock_file(root_dir / 'uniprot/uniprot.fasta')
create_mock_file(root_dir / 'bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_hhm.ffindex')
create_mock_file(root_dir / 'uniref30/UniRef30_2021_03_hhm.ffindex')
create_mock_file(root_dir / 'uniref30/UniRef30_2023_02_hhm.ffindex')
create_mock_file(root_dir / 'pdb70/pdb70_hhm.ffindex')
self.create_mock_file(root_dir / 'uniref90/uniref90.fasta')
self.create_mock_file(root_dir / 'mgnify/mgy_clusters_2022_05.fa')
self.create_mock_file(root_dir / 'uniprot/uniprot.fasta')
self.create_mock_file(root_dir / 'bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_hhm.ffindex')
self.create_mock_file(root_dir / 'uniref30/UniRef30_2021_03_hhm.ffindex')
self.create_mock_file(root_dir / 'uniref30/UniRef30_2023_02_hhm.ffindex')
self.create_mock_file(root_dir / 'pdb70/pdb70_hhm.ffindex')

# Mock hhblits files
hhblits_root = root_dir / 'bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt'
hhblits_files = ['_a3m.ffdata', '_a3m.ffindex', '_cs219.ffdata', '_cs219.ffindex', '_hhmm.ffdata',
'_hhmm.ffindex']
hhblits_files = ['_a3m.ffdata', '_a3m.ffindex', '_cs219.ffdata', '_cs219.ffindex', '_hhmm.ffdata', '_hhmm.ffindex']
for file in hhblits_files:
create_mock_file(hhblits_root / file)
self.create_mock_file(hhblits_root / file)

# Mock uniclust30 files
uniclust_db_root = root_dir / 'uniclust30/uniclust30_2018_08/uniclust30_2018_08'
uniclust_db_files = ['_a3m_db', '_a3m.ffdata', '_a3m.ffindex', '.cs219', '_cs219.ffdata', '_cs219.ffindex',
'_hhm_db', '_hhm.ffdata', '_hhm.ffindex']
for suffix in uniclust_db_files:
create_mock_file(f"{uniclust_db_root}{suffix}")
self.create_mock_file(f"{uniclust_db_root}{suffix}")

# Mock uniref30 files - Adjusted for the correct naming convention
#uniref_db_root = root_dir / 'uniref30/UniRef30_2021_03'
uniref_db_root = root_dir / 'uniref30/UniRef30_2023_02'
uniref_db_files = ['_a3m.ffdata', '_a3m.ffindex', '_hmm.ffdata', '_hmm.ffindex', '_cs.ffdata', '_cs.ffindex']
for suffix in uniref_db_files:
create_mock_file(f"{uniref_db_root}{suffix}")
self.create_mock_file(f"{uniref_db_root}{suffix}")

def run_features_generation(self, file_name, chain_id, file_extension, use_mmseqs2, compress_features=False):
(self.TEST_DATA_DIR / 'features').mkdir(parents=True, exist_ok=True)
(self.TEST_DATA_DIR / 'templates').mkdir(parents=True, exist_ok=True)
self.mock_databases()

with open(f"{self.TEST_DATA_DIR}/description.csv", 'w') as desc_file:
desc_file.write(f">{file_name}_{chain_id}, {file_name}.{file_extension}, {chain_id}\n")

assert Path(f"{self.TEST_DATA_DIR}/fastas/{file_name}_{chain_id}.fasta").exists()

# Prepare the command and arguments
cmd = [
'create_individual_features.py',
'--use_precomputed_msas', 'True',
@@ -110,26 +85,32 @@ def create_mock_file(file_path):
'--output_dir', f"{self.TEST_DATA_DIR}/features",
]
if use_mmseqs2:
cmd.extend(['--use_mmseqs2', 'True',])
print(" ".join(cmd))
# Check the output
cmd.extend(['--use_mmseqs2', 'True'])
if compress_features:
cmd.extend(['--compress_features', 'True'])
subprocess.run(cmd, check=True)
features_dir = self.TEST_DATA_DIR / 'features'

# List all files in the directory
for file in features_dir.iterdir():
if file.is_file():
print(file)
print("pkl path")
print(pkl_path)
assert pkl_path.exists()
assert sto_or_a3m_path.exists()

with open(pkl_path, 'rb') as f:
feats = pickle.load(f).feature_dict
temp_sequence = feats['template_sequence'][0].decode('utf-8')
target_sequence = feats['sequence'][0].decode('utf-8')
atom_coords = feats['template_all_atom_positions'][0]

def validate_generated_features(self, pkl_path, json_path, file_name, file_extension, chain_id, compress_features):
# Validate that the expected output files exist
self.assertTrue(json_path.exists(), f"Metadata JSON file was not created: {json_path}")
self.assertTrue(pkl_path.exists(), f"Pickle file was not created: {pkl_path}")

# Validate the contents of the PKL file
if compress_features:
with lzma.open(pkl_path, 'rb') as f:
monomeric_object = pickle.load(f)
else:
with open(pkl_path, 'rb') as f:
monomeric_object = pickle.load(f)

self.assertTrue(hasattr(monomeric_object, 'feature_dict'), "Loaded object does not have 'feature_dict' attribute.")
features = monomeric_object.feature_dict

# Validate that the expected sequences and atom coordinates are present and valid
temp_sequence = features['template_sequence'][0].decode('utf-8')
target_sequence = features['sequence'][0].decode('utf-8')
atom_coords = features['template_all_atom_positions'][0]
template_path = self.TEST_DATA_DIR / 'templates' / f'{file_name}.{file_extension}'
# Check that template sequence is not empty
assert len(temp_sequence) > 0
# Check that the atom coordinates are not all 0
@@ -168,29 +149,60 @@ def create_mock_file(file_path):
#print(residue_has_nonzero_coords)
#print(len(residue_has_nonzero_coords))

def test_1a_run_features_generation(self):
self.run_features_generation('3L4Q', 'A', 'cif', False)

def test_2c_run_features_generation(self):
self.run_features_generation('3L4Q', 'C', 'pdb', False)

def test_3b_bizarre_filename(self):
self.run_features_generation('RANdom_name1_.7-1_0', 'B', 'pdb', False)

def test_4c_bizarre_filename(self):
self.run_features_generation('RANdom_name1_.7-1_0', 'C', 'pdb', False)

def test_5b_gappy_pdb(self):
self.run_features_generation('GAPPY_PDB', 'B', 'pdb', False)
@parameterized.parameters(
{'compress_features': True, 'file_name': '3L4Q', 'chain_id': 'A', 'file_extension': 'cif'},
{'compress_features': False, 'file_name': '3L4Q', 'chain_id': 'A', 'file_extension': 'cif'},
)
def test_compress_features_flag(self, compress_features, file_name, chain_id, file_extension):
self.run_features_generation(file_name, chain_id, file_extension, use_mmseqs2=False, compress_features=compress_features)

json_pattern = f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}_feature_metadata_*.json'
if compress_features:
json_pattern += '.xz'
metadata_files = list((self.TEST_DATA_DIR / 'features').glob(json_pattern))
self.assertTrue(len(metadata_files) > 0, "Metadata JSON file was not created.")
json_path = metadata_files[0]

pkl_filename = f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}.pkl'
if compress_features:
pkl_filename += '.xz'
pkl_path = self.TEST_DATA_DIR / 'features' / pkl_filename

self.validate_generated_features(pkl_path, json_path, file_name, file_extension, chain_id, compress_features)

# Clean up
pkl_path.unlink(missing_ok=True)
json_path.unlink(missing_ok=True)

@parameterized.parameters(
{'file_name': '3L4Q', 'chain_id': 'A', 'file_extension': 'cif', 'use_mmseqs2': False},
{'file_name': '3L4Q', 'chain_id': 'C', 'file_extension': 'pdb', 'use_mmseqs2': False},
{'file_name': 'RANdom_name1_.7-1_0', 'chain_id': 'B', 'file_extension': 'pdb', 'use_mmseqs2': False},
{'file_name': 'RANdom_name1_.7-1_0', 'chain_id': 'C', 'file_extension': 'pdb', 'use_mmseqs2': False},
{'file_name': 'GAPPY_PDB', 'chain_id': 'B', 'file_extension': 'pdb', 'use_mmseqs2': False},
{'file_name': 'hetatoms', 'chain_id': 'A', 'file_extension': 'pdb', 'use_mmseqs2': False},
)
def test_run_features_generation(self, file_name, chain_id, file_extension, use_mmseqs2):
self.run_features_generation(file_name, chain_id, file_extension, use_mmseqs2)

# Determine the output paths for validation
pkl_filename = f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}.pkl'
pkl_path = self.TEST_DATA_DIR / 'features' / pkl_filename
json_pattern = f'{file_name}_{chain_id}.{file_name}.{file_extension}.{chain_id}_feature_metadata_*.json'
metadata_files = list((self.TEST_DATA_DIR / 'features').glob(json_pattern))
self.assertTrue(len(metadata_files) > 0, "Metadata JSON file was not created.")
json_path = metadata_files[0]

self.validate_generated_features(pkl_path, json_path, file_name, file_extension, chain_id, compress_features=False)

# Clean up
pkl_path.unlink(missing_ok=True)
json_path.unlink(missing_ok=True)

@absltest.skip("use_mmseqs2 must not be set when running with --path_to_mmts")
def test_6a_mmseqs2(self):
self.run_features_generation('3L4Q', 'A', 'cif', True)

def test_7a_hetatoms(self):
self.run_features_generation('hetatoms', 'A', 'pdb', False)



if __name__ == '__main__':
absltest.main()
