diff --git a/.gitignore b/.gitignore
index 2134488..6cb2925 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,5 @@ dev_data/output_old/*
 !dev_data/output/dev_comp_domain_fit_3_domains_10s20q_new_q
 case_study_data
 run
+/dev_data/input/domain_fit_demo_3domains/subunits_cif_test
+*.npz
diff --git a/Reproduce.md b/Reproduce.md
index 2de9157..92d5bff 100644
--- a/Reproduce.md
+++ b/Reproduce.md
@@ -13,7 +13,9 @@ This guide is to reproduce Fig. 4 of our paper.
 1. Download [DiffFit-0.6.2-py3-none-any.whl](https://github.com/nanovis/DiffFit/releases/download/v0.6.2/DiffFit-0.6.2-py3-none-any.whl)
 2. Follow the [installation guide](https://github.com/nanovis/DiffFit?tab=readme-ov-file#install)
-3. Download the [repository](https://github.com/nanovis/DiffFit/archive/refs/heads/main.zip) and unzip it to Desktop
+3. Download the [repository](https://github.com/nanovis/DiffFit/archive/refs/heads/main.zip)
+   and unzip it to your working directory
+   (your Desktop on Windows, or your home folder on Linux).
 4. Open ChimeraX, launch DiffFit, go to the `Settings` tab, change `Fit atoms:` to `All atoms`.
 5. Go to the `Disk` tab, set the parameters as below:
    1. Target Volume: `DiffFit-main\dev_data\input\domain_fit_demo_3domains\density2.mrc`
diff --git a/bundle_info.xml b/bundle_info.xml
index d80160b..f09ef4a 100644
--- a/bundle_info.xml
+++ b/bundle_info.xml
@@ -30,6 +30,7 @@ will be displayed without the ChimeraX- prefix.
+
@@ -48,6 +49,14 @@ will be displayed without the ChimeraX- prefix.
   ChimeraX :: Tool :: DiffFit :: Volume Data :: Differentiable Fit in Map
+
+  ChimeraX :: Command :: dfit :: General ::
+    Fit a single structure into a volume map
+  ChimeraX :: Command :: dfit disk :: General ::
+    Fit a list of structures from disk into a volume map
+  ChimeraX :: Command :: dfit path :: General ::
+    Manipulate the paths in DiffFit results
+
\ No newline at end of file
diff --git a/dev_data/input/domain_fit_demo_3domains/subunits_mrc/I7LV70_D1.mrc b/dev_data/input/domain_fit_demo_3domains/subunits_cif/I7LV70_D1.mrc
similarity index 100%
rename from dev_data/input/domain_fit_demo_3domains/subunits_mrc/I7LV70_D1.mrc
rename to dev_data/input/domain_fit_demo_3domains/subunits_cif/I7LV70_D1.mrc
diff --git a/dev_data/input/domain_fit_demo_3domains/subunits_mrc/I7M317_D1.mrc b/dev_data/input/domain_fit_demo_3domains/subunits_cif/I7M317_D1.mrc
similarity index 100%
rename from dev_data/input/domain_fit_demo_3domains/subunits_mrc/I7M317_D1.mrc
rename to dev_data/input/domain_fit_demo_3domains/subunits_cif/I7M317_D1.mrc
diff --git a/dev_data/input/domain_fit_demo_3domains/subunits_mrc/I7MLV6_D3.mrc b/dev_data/input/domain_fit_demo_3domains/subunits_cif/I7MLV6_D3.mrc
similarity index 100%
rename from dev_data/input/domain_fit_demo_3domains/subunits_mrc/I7MLV6_D3.mrc
rename to dev_data/input/domain_fit_demo_3domains/subunits_cif/I7MLV6_D3.mrc
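The three `dfit` commands declared above are registered by `register_command` in `src/__init__.py` (further down in this patch) and can be driven from a ChimeraX Python script. A minimal sketch, assuming a structure and a map are already open as models #1 and #2; the model IDs, paths, and option values here are illustrative, and the keyword spellings follow the `CmdDesc` registrations in `src/dfit_cmd.py` below:

```python
# Illustrative only -- model IDs, paths, and option values are assumptions.
# `session` is the variable ChimeraX injects into scripts run via `runscript`.
from chimerax.core.commands import run

# Fit the structure in model #1 into the map in model #2
run(session, "dfit #1 in_map #2 level 0.7 num_positions 10 num_rotations 100")

# Batch mode: fit every .pdb/.cif in a folder into a map on disk
run(session, "dfit disk str_dir ./subunits_cif in_map ./density2.mrc level 0.7")

# Inspect the paths stored in a saved result folder
run(session, "dfit path show ./DiffFit_out/disk")
```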
diff --git a/src/DiffAtomComp.py b/src/DiffAtomComp.py
index cb37fb0..4052eb1 100644
--- a/src/DiffAtomComp.py
+++ b/src/DiffAtomComp.py
@@ -1,5 +1,6 @@
 import argparse
 import os, sys
+from pathlib import Path
 from datetime import datetime
 
 import numpy as np
@@ -27,6 +28,33 @@
 # Ignore PDBConstructionWarning for unrecognized 'END' record
 warnings.filterwarnings("ignore", message="Ignoring unrecognized record 'END'", category=PDBConstructionWarning)
 
+from chimerax.geometry.bins import Binned_Transforms
+
+
+class DiffFit_Binned_Transforms(Binned_Transforms):
+    def __init__(self, angle, translation, center=(0, 0, 0), bfactor=2):
+        super().__init__(angle, translation, center, bfactor)
+
+    def one_in_cluster_transform(self, tf):
+        # Return one already-binned transform within the shift/angle
+        # tolerance of tf, or None if tf starts a new cluster.
+        a, x, y, z = c = self.bin_point(tf)
+        clist = self.bins.close_objects(c, self.spacing)
+        if len(clist) == 0:
+            return None
+
+        itf = tf.inverse()
+        d2max = self.translation * self.translation
+        for ctf in clist:
+            cx, cy, cz = ctf * self.center
+            dx, dy, dz = x - cx, y - cy, z - cz
+            d2 = dx * dx + dy * dy + dz * dz
+            if d2 <= d2max:
+                dtf = ctf * itf
+                a = dtf.rotation_angle()
+                if a < self.angle:
+                    return ctf
+
+        return None
 
 def interpolate_coords(coords, inter_folds, inter_kind='quadratic'):
     """Interpolate backbone coordinates."""
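For reference, a minimal sketch of the greedy clustering pass this subclass enables; it mirrors the loop in `cluster_and_sort_sqd_fast` below, and `transforms` is a hypothetical list of `chimerax.geometry.Place` objects:

```python
# Sketch only: greedy clustering with DiffFit_Binned_Transforms.
# `transforms` is a hypothetical list of chimerax.geometry.Place objects.
from math import pi

def greedy_cluster(transforms, shift_tol=3.0, angle_tol=6.0):
    b = DiffFit_Binned_Transforms(angle_tol * pi / 180, shift_tol)
    labels, next_id, id_of = [], 0, {}
    for tf in transforms:
        rep = b.one_in_cluster_transform(tf)   # one in-tolerance hit is enough
        if rep is None:                        # tf founds a new cluster
            b.add_transform(tf)
            id_of[id(tf)] = next_id
            labels.append(next_id)
            next_id += 1
        else:                                  # join the founder's cluster
            labels.append(id_of[id(rep)])
    return labels
```

Returning after the first in-tolerance hit is the point of the subclass: the stock `close_transforms` builds the full neighbour list, which is wasted work when only cluster membership is needed.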
@@ -95,8 +123,13 @@ def q2_unit_coord(Q):
     return np.concatenate((rotated_up, rotated_right), axis=-1)
 
 
-def cluster_and_sort_sqd_fast(e_sqd_log, mol_centers, shift_tolerance: float = 3.0, angle_tolerance: float = 6.0,
-                              sort_column_idx: int = 7, in_contour_threshold = 0.5, correlation_threshold = 0.5):
+def cluster_and_sort_sqd_fast(e_sqd_log, shift_tolerance: float = 3.0, angle_tolerance: float = 6.0,
+                              sort_column_idx: int = 7,
+                              in_contour_threshold: float = 0.5,
+                              correlation_threshold: float = 0.5,
+                              df_cid_threshold: float = 0.15,
+                              save_log=False,
+                              log_path=""):
     """
     Cluster the fitting results in sqd table by thresholding on shift and quaternion
     Return the sorted cluster representatives
@@ -113,13 +146,13 @@
     3. sort the cluster table in descending order by correlation, or the metric at the sort_column_idx
 
     @param e_sqd_log: fitting results in sqd table
-    @param mol_centers: molecule atom coords centers
     @param shift_tolerance: shift tolerance in Angstrom
     @param angle_tolerance: angle tolerance in degrees
     @param sort_column_idx: the column to sort, 9-th column is the correlation
     @return: cluster representative table sorted in descending order
     """
-    from chimerax.geometry import bins
+    mol_center = np.array([0.0, 0.0, 0.0])
+    from chimerax.geometry import Place
 
     N_mol, N_record, N_iter, N_metric = e_sqd_log.shape
@@ -137,23 +170,31 @@ def cluster_and_sort_sqd_fast(e_sqd_log, mol_centers, shift_tolerance: float = 3
     # Use the generated meshgrid and max_sort_column_metric_idx to index into e_sqd_log
     sqd_highest_corr_np = e_sqd_log[dims_0, dims_1, max_sort_column_metric_idx]
 
+    timer_start = datetime.now()
+    if save_log:
+        with open(log_path, "a") as log_file:
+            log_file.write(f"DiffFit fit_res filtering starts: {timer_start}\n")
+
     fit_res_filtered = []
     fit_res_filtered_indices = []
     in_contour_col_idx = 11
     correlation_col_idx = 9
+    df_cid_col_idx = 14
     for mol_idx in range(N_mol):
         sqd_highest_corr_np_mol = sqd_highest_corr_np[mol_idx]
 
         # Fetch the columns of interest
         in_contour_percentage_column = sqd_highest_corr_np_mol[:, in_contour_col_idx]
         correlation_column = sqd_highest_corr_np_mol[:, correlation_col_idx]
+        df_cid_column = sqd_highest_corr_np_mol[:, df_cid_col_idx]
 
         # Create masks for the filtering conditions
         in_contour_mask = in_contour_percentage_column >= in_contour_threshold
         correlation_mask = correlation_column >= correlation_threshold
+        df_cid_mask = df_cid_column >= df_cid_threshold
 
         # Combine the masks to get a final filter
-        combined_mask = in_contour_mask & correlation_mask
+        combined_mask = in_contour_mask & correlation_mask & df_cid_mask
 
         # Apply the mask to filter the original array and also retrieve the indices
         filtered_indices = np.where(combined_mask)  # Get the indices of the filtered rows
@@ -162,12 +203,21 @@ def cluster_and_sort_sqd_fast(e_sqd_log, shift_tolerance: float = 3
         fit_res_filtered.append(filtered_array)
         fit_res_filtered_indices.append(filtered_indices[0])
 
+    if save_log:
+        with open(log_path, "a") as log_file:
+            log_file.write(f"DiffFit fit_res filtering time elapsed: {datetime.now() - timer_start}\n"
+                           f"-------\n")
+
     sqd_clusters = []
     for mol_idx in range(N_mol):
         mol_shift = fit_res_filtered[mol_idx][:, :3]
         mol_q = fit_res_filtered[mol_idx][:, 3:7]
 
+        if save_log:
+            with open(log_path, "a") as log_file:
+                log_file.write(f"Clustering {len(mol_shift)} fits for mol_idx: {mol_idx}\n")
+            timer_start = datetime.now()
+
         T = []
         for i in range(len(mol_shift)):
             shift = mol_shift[i]
@@ -181,21 +231,34 @@ def cluster_and_sort_sqd_fast(e_sqd_log, mol_centers, shift_tolerance: float = 3
             transformation = Place(matrix=T_matrix)
             T.append(transformation)
 
-        b = bins.Binned_Transforms(angle_tolerance * pi / 180, shift_tolerance, mol_centers[mol_idx])
+        if save_log:
+            with open(log_path, "a") as log_file:
+                log_file.write(f"Convert to matrix time: {datetime.now() - timer_start}\n")
+            timer_start = datetime.now()
+
+        b = DiffFit_Binned_Transforms(angle_tolerance * pi / 180, shift_tolerance, mol_center)
         mol_transform_label = []
         unique_id = 0
         T_ID_dict = {}
         for i in range(len(mol_shift)):
             ptf = T[i]
-            close = b.close_transforms(ptf)
-            if len(close) == 0:
+            in_cluster = b.one_in_cluster_transform(ptf)
+            if in_cluster is None:
                 b.add_transform(ptf)
                 mol_transform_label.append(unique_id)
                 T_ID_dict[id(ptf)] = unique_id
                 unique_id = unique_id + 1
             else:
-                mol_transform_label.append(T_ID_dict[id(close[0])])
-                T_ID_dict[id(ptf)] = T_ID_dict[id(close[0])]
+                mol_transform_label.append(T_ID_dict[id(in_cluster)])
+                T_ID_dict[id(ptf)] = T_ID_dict[id(in_cluster)]
+
+            if save_log and (i + 1) % 10000 == 0:
+                with open(log_path, "a") as log_file:
+                    log_file.write(f"Clustered {i+1} fits: {datetime.now()}\n")
+
+        if save_log:
+            with open(log_path, "a") as log_file:
+                log_file.write(f"ChimeraX bin clustering: {datetime.now() - timer_start}\n")
 
         unique_labels, indices, counts = np.unique(mol_transform_label, axis=0, return_inverse=True, return_counts=True)
@@ -366,12 +429,14 @@ def mrc_to_npy(mrc_filename):
 
 def mrc_folder_to_npy_list(mrc_folder):
     sim_map_list = []
-    for file_name in os.listdir(mrc_folder):
+    for file_name in sorted(os.listdir(mrc_folder)):
         full_path = os.path.join(mrc_folder, file_name)
 
         # Check if the current path is a file and not a directory
         if os.path.isfile(full_path):
-            data, steps, origin = mrc_to_npy(full_path)
-            sim_map_list.append((data, steps, origin))
+            file_extension = Path(file_name).suffix.lower()
+            if file_extension in ['.mrc', '.map']:
+                data, steps, origin = mrc_to_npy(full_path)
+                sim_map_list.append((data, steps, origin))
 
     return sim_map_list
@@ -435,18 +500,16 @@ def transform_to_angstrom_space(ndc_shift, box_size, box_origin, atom_center_in_
 
 def read_file_and_get_coordinates(file_path, fit_atom_mode="Backbone"):
     # Determine file extension
-    file_extension = os.path.splitext(file_path)[1].lower()
+    file_extension = Path(file_path).suffix.lower()
 
     # Initialize parser based on file extension
     if file_extension == '.cif':
         parser = MMCIFParser()
     elif file_extension == '.pdb':
         parser = PDBParser()
-    else:
-        raise ValueError("Unsupported file format. Please provide a .mmcif or .pdb file.")
 
     # Parse the structure
-    structure_id = os.path.basename(file_path).split('.')[0]  # Use file name as structure ID
+    structure_id = Path(file_path).stem  # Use file name as structure ID
     structure = parser.get_structure(structure_id, file_path)
 
     # Initialize a list to hold all atom coordinates
@@ -630,10 +693,11 @@ def center_atom_coords_list(atom_coords_list, mol_centers):
 def read_all_files_to_atom_coords_list(structures_dir, fit_atom_mode="Backbone"):
     atom_coords_list = []
     # List all files in the given directory
-    for file_name in os.listdir(structures_dir):
+    for file_name in sorted(os.listdir(structures_dir)):
         full_path = os.path.join(structures_dir, file_name)
-        # Check if the current path is a file and not a directory
-        if os.path.isfile(full_path):
+
+        file_extension = Path(file_name).suffix.lower()
+        if file_extension in ['.pdb', '.cif']:
             # Read the atom coordinates from the file
             atom_coords = read_file_and_get_coordinates(full_path, fit_atom_mode)
             # Append the coordinates to the list
@@ -658,13 +722,13 @@ def rotate_centers(mol_centers, e_quaternions):
 
 def calculate_metrics(render, elements_sim_density):
     # Mask to filter elements in render that are greater than zero
-    mask = render > 0
+    mask = (render > 0.0).float()
 
     # Apply the mask to the render and elements_sim_density tensors
     render_filtered = render * mask
     elements_sim_density_filtered = elements_sim_density * mask
 
-    mask_sum = mask.float().sum(dim=-1, keepdim=True)
-    in_contour_percentage = mask.float().mean(dim=-1)
+    mask_sum = mask.sum(dim=-1, keepdim=True)
+    in_contour_percentage = mask.mean(dim=-1)
 
     # Calculation of correlation
     # First, normalize the inputs to have zero mean and unit variance, as Pearson's correlation requires
@@ -691,7 +755,29 @@ def calculate_metrics(render, elements_sim_density):
 
     correlation = overlap / (render_norm * elements_sim_density_norm)
 
-    return torch.nan_to_num(torch.stack((overlap_mean, correlation, cam, in_contour_percentage), dim=-1))
+    average_density_inside = render_filtered.sum(dim=-1) / mask.sum(dim=-1)
+    average_density_all = render.mean(dim=-1)
+
+    weight_c = 1.0/3.0
+    weight_i = 1.0/3.0
+    weight_d = 1.0/3.0
+
+    good_correlation = 0.85
+    good_in = 0.3
+    good_average_density_all = -0.1
+
+    df_cid = (weight_c * (correlation - good_correlation) / (1.0 - good_correlation) +
+              weight_i * (in_contour_percentage - good_in) / (1.0 - good_in) +
+              weight_d * (average_density_all - good_average_density_all) / (0.5 - good_average_density_all))
+
+    return torch.nan_to_num(torch.stack((
+        overlap_mean,
+        correlation,
+        cam,
+        in_contour_percentage,
+        average_density_inside,
+        average_density_all,
+        df_cid), dim=-1))
 
 
 def diff_fit(volume_list: list,
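The new `df_cid` composite rescales each metric by how far it sits above a "good" anchor, so a fit exactly at the anchors scores 0 and an ideal fit scores about 1. A quick numeric check (the metric values are made up):

```python
# Quick numeric check of the df_cid composite (metric values are made up).
# Each term is (metric - good_anchor) / (best - good_anchor).
correlation, in_contour, avg_density_all = 0.92, 0.65, 0.20

df_cid = ((1/3) * (correlation - 0.85) / (1.0 - 0.85) +
          (1/3) * (in_contour - 0.3) / (1.0 - 0.3) +
          (1/3) * (avg_density_all - (-0.1)) / (0.5 - (-0.1)))
print(round(df_cid, 3))  # ~0.489; fits below df_cid_threshold (0.15) are dropped
```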
@@ -753,7 +839,7 @@ def diff_fit(volume_list: list,
 
     # ======= get atom coords
     atom_coords_list = mol_coords  # atom coords as [x, y, z]
-    mol_centers = [np.mean(coords, axis=0) for coords in atom_coords_list]
+    mol_num_atoms = [len(coords) for coords in atom_coords_list]
     num_molecules = len(atom_coords_list)
 
     # read simulated map
@@ -786,7 +872,7 @@ def diff_fit(volume_list: list,
 
     # Training loop
     log_every = 10
-    e_sqd_log = torch.zeros([num_molecules, N_quaternions, N_shifts, int(n_iters / 10) + 2, 12], device=device)
+    e_sqd_log = torch.zeros([num_molecules, N_quaternions, N_shifts, int(n_iters / 10) + 2, 15], device=device)
     # [x, y, z, w, -x, -y, -z, occupied_density_sum]
 
     with torch.no_grad():
@@ -807,7 +893,7 @@ def diff_fit(volume_list: list,
         first_layer_positive_density_sum = torch.zeros([num_molecules, N_quaternions, N_shifts], device=device)
         occupied_density_sum = torch.zeros([num_molecules, N_quaternions, N_shifts], device=device)
 
-        metrics_table = torch.zeros([num_molecules, N_quaternions, N_shifts, 4], device=device)
+        metrics_table = torch.zeros([num_molecules, N_quaternions, N_shifts, 7], device=device)
 
         for mol_idx in range(num_molecules):
             grid = transform_coords(atom_coords_list[mol_idx],
@@ -841,7 +927,7 @@ def diff_fit(volume_list: list,
             e_sqd_log[:, :, :, log_idx, 0:3] = e_shifts
             e_sqd_log[:, :, :, log_idx, 3:7] = e_quaternions
             e_sqd_log[:, :, :, log_idx, 7] = first_layer_positive_density_sum
-            e_sqd_log[:, :, :, log_idx, 8:12] = metrics_table
+            e_sqd_log[:, :, :, log_idx, 8:15] = metrics_table
 
         if save_results:
             with open(f"{out_dir}/log.log", "a") as log_file:
@@ -869,7 +955,7 @@ def diff_fit(volume_list: list,
                             target_vol_path=vol_path,
                             target_surface_threshold=target_surface_threshold,
                             mol_paths=[mol_path],
-                            mol_centers=mol_centers,
+                            mol_num_atoms=mol_num_atoms,
                             opt_res=e_sqd_log_np)
 
     # np.save(f"{out_dir}/sampled_coords.npy", sampled_coords)
@@ -883,15 +969,14 @@ def diff_fit(volume_list: list,
     return (vol_path,
             target_surface_threshold,
             [mol_path],
-            mol_centers,
+            mol_num_atoms,
             e_sqd_log_np)
 
 
 def diff_atom_comp(target_vol_path: str,
                    target_surface_threshold: float,
-                   min_cluster_size: float,
+                   min_cluster_size: float,  # not in use, use 100 as a placeholder
                    structures_dir: str,
-                   structures_sim_map_dir: str,
                    fit_atom_mode: str = "Backbone",
                    Gaussian_mode: str = "Gaussian with negative (shrink)",
                    N_shifts: int = 10,
@@ -950,14 +1035,14 @@ def diff_atom_comp(target_vol_path: str,
     num_molecules = len(atom_coords_list)
 
     # read simulated map
-    sim_map_list = mrc_folder_to_npy_list(structures_sim_map_dir)
+    sim_map_list = mrc_folder_to_npy_list(structures_dir)
     elements_sim_density_list = sample_sim_map(atom_coords_list, sim_map_list, num_molecules, device)
 
     # center the mol
     mol_centers = [np.mean(coords, axis=0) for coords in atom_coords_list]
     atom_coords_list = center_atom_coords_list(atom_coords_list, mol_centers)
-    # re-calculate the centers, should be all near zero
-    mol_centers = [np.mean(coords, axis=0) for coords in atom_coords_list]
+
+    mol_num_atoms = [len(coords) for coords in atom_coords_list]
 
     # ======= optimization
@@ -987,7 +1072,7 @@ def diff_atom_comp(target_vol_path: str,
 
     # Training loop
     log_every = 10
-    e_sqd_log = torch.zeros([num_molecules, N_quaternions, N_shifts, int(n_iters / 10) + 2, 12], device=device)
+    e_sqd_log = torch.zeros([num_molecules, N_quaternions, N_shifts, int(n_iters / 10) + 2, 15], device=device)
     # [x, y, z, w, -x, -y, -z, occupied_density_sum]
 
     with torch.no_grad():
@@ -1008,7 +1093,7 @@ def diff_atom_comp(target_vol_path: str,
         first_layer_positive_density_sum = torch.zeros([num_molecules, N_quaternions, N_shifts], device=device)
         occupied_density_sum = torch.zeros([num_molecules, N_quaternions, N_shifts], device=device)
 
-        metrics_table = torch.zeros([num_molecules, N_quaternions, N_shifts, 4], device=device)
+        metrics_table = torch.zeros([num_molecules, N_quaternions, N_shifts, 7], device=device)
 
         for mol_idx in range(num_molecules):
             grid = transform_coords(atom_coords_list[mol_idx],
@@ -1042,7 +1127,7 @@ def diff_atom_comp(target_vol_path: str,
             e_sqd_log[:, :, :, log_idx, 0:3] = e_shifts
             e_sqd_log[:, :, :, log_idx, 3:7] = e_quaternions
             e_sqd_log[:, :, :, log_idx, 7] = first_layer_positive_density_sum
-            e_sqd_log[:, :, :, log_idx, 8:12] = metrics_table
+            e_sqd_log[:, :, :, log_idx, 8:15] = metrics_table
 
         with open(f"{out_dir}/log.log", "a") as log_file:
             log_file.write(f"Epoch: {epoch + 1:05d}, "
@@ -1063,15 +1148,17 @@ def diff_atom_comp(target_vol_path: str,
     e_sqd_log[:, :, :, :, 3:7] /= q_norms
 
     mol_paths = []
-    for file_name in os.listdir(structures_dir):
+    for file_name in sorted(os.listdir(structures_dir)):
         full_path = os.path.join(structures_dir, file_name)
-        mol_paths.append(full_path)
+        file_extension = Path(file_name).suffix.lower()
+        if file_extension in ['.pdb', '.cif']:
+            mol_paths.append(full_path)
 
     np.savez_compressed(f"{out_dir}/fit_res.npz",
                         target_vol_path=target_vol_path,
                         target_surface_threshold=target_surface_threshold,
                         mol_paths=mol_paths,
-                        mol_centers=mol_centers,
+                        mol_num_atoms=mol_num_atoms,
                         opt_res=e_sqd_log.detach().cpu().numpy())
 
     # np.save(f"{out_dir}/sampled_coords.npy", sampled_coords)
@@ -1085,7 +1172,7 @@ def diff_atom_comp(target_vol_path: str,
     return (target_vol_path,
             target_surface_threshold,
             mol_paths,
-            mol_centers,
+            mol_num_atoms,
             e_sqd_log)
 
 
@@ -1114,8 +1201,6 @@ def parse_ints(arg):
     parser.add_argument('--structures_dir', type=str,
                         help="directory containing the structures to be fit")
-    parser.add_argument('--structures_sim_map_dir', type=str,
-                        help="directory containing the simulated map from the structures to be fit")
 
     parser.add_argument('--out_dir', type=str, default="out",
                         help="Output directory")
@@ -1142,7 +1227,6 @@ def parse_ints(arg):
         args.target_surface_threshold,
         args.min_cluster_size,
         args.structures_dir,
-        args.structures_sim_map_dir,
         out_dir=args.out_dir,
         out_dir_exist_ok=args.out_dir_exist_ok,
         N_shifts=args.N_shifts,
diff --git a/src/__init__.py b/src/__init__.py
index 91bf11c..03253b2 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -38,5 +38,38 @@ def get_class(class_name):
             return tool.DiffFitTool
         raise ValueError("Unknown class name '%s'" % class_name)
 
+    @staticmethod
+    def register_command(bi, ci, logger):
+        # bi is an instance of chimerax.core.toolshed.BundleInfo
+        # ci is an instance of chimerax.core.toolshed.CommandInfo
+        # logger is an instance of chimerax.core.logger.Logger
+
+        # We check the name of the command, which should match
+        # one of the ones listed in bundle_info.xml
+        # (without the leading and trailing whitespace),
+        # and import the function to call and its argument
+        # description from the ``dfit_cmd`` module.
+        # If the description does not contain a synopsis, we
+        # add the one in ``ci``, which comes from bundle_info.xml.
+        from . import dfit_cmd
+        if ci.name == "dfit":
+            func = dfit_cmd.dfit
+            desc = dfit_cmd.dfit_desc
+        elif ci.name == "dfit disk":
+            func = dfit_cmd.dfit_disk
+            desc = dfit_cmd.dfit_disk_desc
+        elif ci.name == "dfit path":
+            func = dfit_cmd.dfit_path
+            desc = dfit_cmd.dfit_path_desc
+        else:
+            raise ValueError("trying to register unknown command: %s" % ci.name)
+        if desc.synopsis is None:
+            desc.synopsis = ci.synopsis
+
+        # We then register the function as the command callback
+        # with the chimerax.core.commands module.
+        from chimerax.core.commands import register
+        register(ci.name, desc, func)
+
 # Create the ``bundle_api`` object that ChimeraX expects.
 bundle_api = _MyAPI()
\ No newline at end of file
diff --git a/src/convert2mrc_npy.py b/src/convert2mrc_npy.py
index b448fd1..5e5f6ea 100644
--- a/src/convert2mrc_npy.py
+++ b/src/convert2mrc_npy.py
@@ -11,6 +11,7 @@
 from chimerax.core.commands import run
 from datetime import datetime
 import numpy as np
+from pathlib import Path
 
 timer_start = datetime.now()
 
@@ -21,7 +22,7 @@
 resolution = float(sys.argv[4])  # Resolution for simulated MRC files
 gridSpacing = float(sys.argv[5])  # gridSpacing for simulated MRC files
 
-for file_name in os.listdir(structures_dir):
+for file_name in sorted(os.listdir(structures_dir)):
     print(f"\n======= Processing {file_name} =======")
     full_path = os.path.join(structures_dir, file_name)
     # Check if the current path is a file and not a directory
@@ -30,7 +31,7 @@
         structure = run(session, f'open {full_path}')[0]
 
         # Get the base name of the input structure for naming output files
-        structure_basename = os.path.basename(full_path).split('.')[0]
+        structure_basename = Path(full_path).stem
 
         # Save the structure's coordinates as a npy file
         npy_filename = f"{structure_basename}.npy"
diff --git a/src/dfit_cmd.py b/src/dfit_cmd.py
new file mode 100644
index 0000000..3e7300d
--- /dev/null
+++ b/src/dfit_cmd.py
@@ -0,0 +1,493 @@
+# vim: set expandtab shiftwidth=4 softtabstop=4:
+
+from chimerax.core.commands import CmdDesc      # Command description
+from chimerax.atomic import StructureArg, AtomsArg  # Collection of atoms argument
+from chimerax.core.commands import (BoolArg, ColorArg, IntArg, FloatArg, StringArg,
+                                    SaveFolderNameArg, OpenFileNameArg, SaveFileNameArg,
+                                    OpenFolderNameArg)
+from chimerax.core.commands import EmptyArg     # (see below)
+from chimerax.core.commands import Or, Bounded  # Argument modifiers
+from chimerax.map.mapargs import MapArg
+
+import os
+import numpy as np
+import torch
+from datetime import datetime
+
+from .DiffAtomComp import diff_atom_comp, cluster_and_sort_sqd_fast, diff_fit, conv_volume, numpy2tensor, \
+    linear_norm_tensor
+import ast
+from chimerax.core.commands import run
+
+# ==========================================================================
+# Functions and descriptions for registering using ChimeraX bundle API
+# ==========================================================================
+
+
+def dfit(session, mol, in_map,
+         level=None,
+         sim_res=5.0,
+         num_positions=10,
+         num_rotations=100,
+         smooth_by="PyTorch iterative Gaussian",
+         smooth_loops=3,
+         kernel_sizes="[5, 5, 5]",
+         Gaussian_mode="Gaussian with negative (shrink)",
+         fit_atom_mode="Backbone",
+         out_dir="DiffFit_out/interactive",
+         device=None):
+    """Fit a single structure into a volume map"""
+
+    _save_results = True
+    _out_dir_exist_ok = True
+    _out_dir = out_dir
+
+    _use_level = in_map.maximum_surface_level
+    if level is not None:
+        _use_level = level
+
+    _use_device = "cpu"
+    if torch.cuda.is_available():
+        _use_device = "cuda:0"
+    if device is not None:
+        _use_device = device
+
+    if _save_results:
+        os.makedirs(_out_dir, exist_ok=_out_dir_exist_ok)
+        with open(f"{_out_dir}/log.log", "a") as log_file:
+            log_file.write(f"=======\n"
+                           f"Wall clock time: {datetime.now()}\n"
+                           f"-------\n"
+                           f"Interactive mode\n"
+                           f"Target Volume: {in_map.path}\n"
+                           f"Structure: {mol.filename}\n"
+                           f"Target Surface Threshold: {_use_level}\n"
+                           f"-------\n"
+                           f"Sim-map resolution: {sim_res}\n"
+                           f"# positions: {num_positions}\n"
+                           f"# rotations: {num_rotations}\n"
+                           f"Smooth by: \"{smooth_by}\"\n"
+                           f"Smooth loops: {smooth_loops}\n"
+                           f"Kernel sizes: \"{kernel_sizes}\"\n"
+                           f"Gaussian mode: \"{Gaussian_mode}\"\n"
+                           f"Fit atom mode: \"{fit_atom_mode}\"\n"
+                           f"Device: \"{_use_device}\"\n"
+                           f"-------\n")
+
+    from chimerax.core import tools
+    from .tool import DiffFitTool
+    df = tools.get_singleton(session, DiffFitTool, 'DiffFit', create=True)
+
+    df.disable_spheres_clicked()
+    df.fit_mol_list = [mol]
+    df.fit_vol = in_map
+    df.mol = mol
+
+    single_fit_timer_start = datetime.now()
+
+    # Prepare mol and vol
+    vol_matrix = in_map.full_matrix()
+
+    # Copy vol and make it clean after thresholding
+    vol_copy = in_map.writable_copy()
+    vol_copy_matrix = vol_copy.data.matrix()
+    vol_copy_matrix[vol_copy_matrix < _use_level] = 0
+    vol_copy.data.values_changed()
+
+    # Smooth the volume
+    volume_conv_list = _create_volume_conv_list(session,
+                                                vol_copy,
+                                                smooth_by, smooth_loops, kernel_sizes, Gaussian_mode,
+                                                _use_device)
+    vol_copy.delete()
+
+    # Apply the user's transformation and center mol
+    from chimerax.geometry import Place
+    mol_center = mol.atoms.coords.mean(axis=0)
+    transform = Place(origin=-mol_center)
+    mol.atoms.transform(transform)
+    mol.position = Place()
+
+    # Simulate a map for the mol
+    from chimerax.map.molmap import molecule_map
+    mol_vol = molecule_map(session, mol.atoms, sim_res,
+                           grid_spacing=in_map.data.step[0])
+
+    input_coords = None
+    if fit_atom_mode == "Backbone":
+        backbone_atoms = ['N', 'CA', 'C', 'O']
+        is_backbone = np.isin(mol.atoms.names, backbone_atoms)
+        input_coords = mol.atoms.scene_coords[is_backbone]
+    elif fit_atom_mode == "All":
+        input_coords = mol.atoms.scene_coords
+
+    # ======= Generate q_shells
+
+    # Fit
+    timer_start = datetime.now()
+
+    if _save_results:
+        with open(f"{_out_dir}/log.log", "a") as log_file:
+            log_file.write(f"DiffFit optimization starts: {timer_start}\n")
+
+    (_,
+     _,
+     df.mol_paths,
+     df.mol_num_atoms,
+     df.fit_result) = diff_fit(
+        volume_conv_list,
+        in_map.path,
+        _use_level,
+        in_map.data.step,
+        in_map.data.origin,
+        10,
+        [input_coords],
+        mol.filename,
+        [(mol_vol.full_matrix(), mol_vol.data.step, mol_vol.data.origin)],
+        N_shifts=num_positions,
+        N_quaternions=num_rotations,
+        save_results=_save_results,
+        out_dir=_out_dir,
+        out_dir_exist_ok=_out_dir_exist_ok,
+        device=_use_device
+    )
+    timer_stop = datetime.now()
+    print(f"\nDiffFit optimization time elapsed: {timer_stop - timer_start}\n\n")
+
+    if _save_results:
+        with open(f"{_out_dir}/log.log", "a") as log_file:
+            log_file.write(f"-------\n"
+                           f"DiffFit optimization time elapsed: {timer_stop - timer_start}\n")
+
+    mol_vol.delete()
+
+    df._view_input_mode.setCurrentText("interactive")
+    df._view_input_mode_changed()
+    df.interactive_fit_result_ready = True
+    df.show_results(df.fit_result, df.mol_num_atoms, df.mol_paths)
+
+    df.tab_widget.setCurrentWidget(df.tab_view_group)
+
+    df.select_table_item(0)
+    run(session, "view orient")
+
+    timer_stop = datetime.now()
+    print(f"\nDiffFit total time elapsed: {timer_stop - single_fit_timer_start}\n\n")
+
+    if _save_results:
+        metric_json = df.return_cluster_metric_json(0)
+        with open(f"{_out_dir}/log.log", "a") as log_file:
+            log_file.write(f"DiffFit total time elapsed: {timer_stop - single_fit_timer_start}\n"
+                           f"-------\n"
+                           f"DiffFit top fit metric:\n"
+                           f"{metric_json}\n"
+                           f"=======\n\n")
+
+
+dfit_desc = CmdDesc(required=[("mol", StructureArg)],
+                    keyword=[("in_map", MapArg),
+                             ("level", FloatArg),
+                             ("sim_res", FloatArg),
+                             ("num_positions", IntArg),
+                             ("num_rotations", IntArg),
+                             ("smooth_by", StringArg),
+                             ("smooth_loops", IntArg),
+                             ("kernel_sizes", StringArg),
+                             ("Gaussian_mode", StringArg),
+                             ("fit_atom_mode", StringArg),
+                             ("out_dir", SaveFolderNameArg),
+                             ("device", StringArg)],
+                    required_arguments=["in_map"])
+
+# Example commands
+# dfit #1 in #2
+# dfit #1 in #2
+#      level 0.7 sim_res 5.0
+#      num_p 10 num_r 100
+#      smooth_by smooth_loops kernel_sizes gaussian_mode
+#      fit_atom_mode out_dir device
+
+
+def dfit_disk(session, str_dir, in_map, level,
+
+              num_positions=10,
+              num_rotations=100,
+
+              smooth_loops=3,
+              kernel_sizes="[5, 5, 5]",
+              smooth_weights="[1.0, 1.0, 1.0]",
+
+              Gaussian_mode="Gaussian with negative (shrink)",
+              fit_atom_mode="Backbone",
+              negative_space=-0.5,
+
+              learning_rate=0.01,
+              n_iters=201,
+
+              out_dir="DiffFit_out/disk",
+              device=None):
+    """Fit a list of structures from disk into a volume map"""
+
+    from chimerax.core import tools
+    from .tool import DiffFitTool
+    df = tools.get_singleton(session, DiffFitTool, 'DiffFit', create=True)
+
+    if df is not None:
+        if df.interactive_fit_result_ready:
+            df.session.logger.error("You have run the fitting in Interactive mode. "
+                                    "Please run the following command: \n\n"
+                                    "close session\n\n"
+                                    "and then launch DiffFit again to run the fitting in Disk mode.")
+            return
+
+    _save_results = True
+    _out_dir_exist_ok = True
+    _out_dir = out_dir
+
+    _use_device = "cpu"
+    if torch.cuda.is_available():
+        _use_device = "cuda:0"
+    if device is not None:
+        _use_device = device
+
+    os.makedirs(_out_dir, exist_ok=_out_dir_exist_ok)
+    with open(f"{_out_dir}/log.log", "a") as log_file:
+        log_file.write(f"=======\n"
+                       f"Wall clock time: {datetime.now()}\n"
+                       f"-------\n"
+                       f"Disk mode\n"
+                       f"Structures Folder: {str_dir}\n"
+                       f"Target Volume: {in_map}\n"
+                       f"Target Surface Threshold: {level}\n"
+                       f"-------\n"
+                       f"# positions: {num_positions}\n"
+                       f"# rotations: {num_rotations}\n"
+                       f"Conv. loops: {smooth_loops}\n"
+                       f"Conv. kernel sizes: \"{kernel_sizes}\"\n"
+                       f"Conv. weights: \"{smooth_weights}\"\n"
+                       f"Gaussian mode: \"{Gaussian_mode}\"\n"
+                       f"Fit atom mode: \"{fit_atom_mode}\"\n"
+                       f"Negative space: \"{negative_space}\"\n"
+                       f"Learning rate: \"{learning_rate}\"\n"
+                       f"# iters: \"{n_iters}\"\n"
+                       f"Out dir: \"{_out_dir}\"\n"
+                       f"Device: \"{_use_device}\"\n"
+                       f"-------\n")
+
+    if df is not None:
+        df.disable_spheres_clicked()
+
+    disk_fit_timer_start = datetime.now()
+
+    print("Running the computation...")
+
+    with open(f"{_out_dir}/log.log", "a") as log_file:
+        log_file.write(f"DiffFit optimization starts: {disk_fit_timer_start}\n")
+
+    timer_start = datetime.now()
+    (target_vol_path,
+     target_surface_threshold,
+     mol_paths,
+     mol_num_atoms,
+     e_sqd_log) = diff_atom_comp(
+        target_vol_path=in_map,
+        target_surface_threshold=level,
+        min_cluster_size=100,
+        structures_dir=str_dir,
+        fit_atom_mode=fit_atom_mode,
+        Gaussian_mode=Gaussian_mode,
+        N_shifts=num_positions,
+        N_quaternions=num_rotations,
+        negative_space_value=negative_space,
+        learning_rate=learning_rate,
+        n_iters=n_iters,
+        out_dir=_out_dir,
+        out_dir_exist_ok=_out_dir_exist_ok,
+        conv_loops=smooth_loops,
+        conv_kernel_sizes=ast.literal_eval(kernel_sizes),
+        conv_weights=ast.literal_eval(smooth_weights),
+        device=_use_device
+    )
+
+    timer_stop = datetime.now()
+    print(f"\nDiffFit optimization time elapsed: {timer_stop - timer_start}\n\n")
+
+    with open(f"{_out_dir}/log.log", "a") as log_file:
+        log_file.write(f"-------\n"
+                       f"DiffFit optimization time elapsed: {timer_stop - timer_start}\n")
+
+    if df is not None:
+        # copy the directories
+        df.dataset_folder.setText(_out_dir)
+
+        # output is tensor, convert to numpy
+        df.show_results(e_sqd_log.detach().cpu().numpy(),
+                        mol_num_atoms,
+                        mol_paths,
+                        target_vol_path,
+                        target_surface_threshold)
+        df.tab_widget.setCurrentWidget(df.tab_view_group)
+        df.select_table_item(0)
+        run(session, "view orient")
+
+    timer_stop = datetime.now()
+    print(f"\nDiffFit total time elapsed: {timer_stop - disk_fit_timer_start}\n\n")
+
+    if df is not None:
+        metric_json = df.return_cluster_metric_json(0)
+        with open(f"{_out_dir}/log.log", "a") as log_file:
+            log_file.write(f"DiffFit total time elapsed: {timer_stop - disk_fit_timer_start}\n"
+                           f"-------\n"
+                           f"DiffFit top fit metric:\n"
+                           f"{metric_json}\n"
+                           f"=======\n\n")
+
+
+dfit_disk_desc = CmdDesc(keyword=[("str_dir", OpenFolderNameArg),
+                                  ("in_map", OpenFileNameArg),
+                                  ("level", FloatArg),
+
+                                  ("num_positions", IntArg),
+                                  ("num_rotations", IntArg),
+
+                                  ("smooth_loops", IntArg),
+                                  ("kernel_sizes", StringArg),
+                                  ("smooth_weights", StringArg),
+
+                                  ("Gaussian_mode", StringArg),
+                                  ("fit_atom_mode", StringArg),
+                                  ("negative_space", FloatArg),
+
+                                  ("learning_rate", FloatArg),
+                                  ("n_iters", IntArg),
+
+                                  ("out_dir", SaveFolderNameArg),
+                                  ("device", StringArg)],
+                         required_arguments=["str_dir",
+                                             "in_map",
+                                             "level"])
+
+# dfit disk str_dir <dir> in_map <map file> level 0.7
+#      num_p 10 num_r 100
+#      smooth_loops smooth_weights kernel_sizes gaussian_mode
+#      fit_atom_mode out_dir device
+#      negative_space
+#      learning_rate
+#      n_iters
print(f"Mol paths: {mol_paths}") + + elif mode == "vol": + if new_path is None: + session.logger.error("You must specify a new path.") + return + + target_vol_path = new_path + print(f"New Vol path: {new_path}") + + elif mode == "mol": + if old_path is None or new_path is None: + session.logger.error("You must specify both old and new paths.") + return + + # Replace the old path with the new path in mol_paths array + mol_paths = np.array([path.replace(old_path, new_path) if old_path in path else path for path in mol_paths]) + print(f"Replaced \"{old_path}\" with \"{new_path}\" in the molecules' path.") + + np.savez(f"{res_dir}/fit_res.npz", + target_vol_path=target_vol_path, + target_surface_threshold=target_surface_threshold, + mol_paths=mol_paths, + mol_centers=mol_centers, + opt_res=opt_res) + + +dfit_path_desc = CmdDesc(required=[("mode", StringArg), + ("res_dir", OpenFolderNameArg)], + keyword=[("old_path", StringArg), + ("new_path", StringArg)]) + +# ========================================================================== +# Functions intended only for internal use by bundle +# ========================================================================== + + +def _create_volume_conv_list(session, vol, smooth_by, smooth_loops, smooth_kernel_sizes, Gaussian_mode, device, negative_space_value=-0.5): + # From here on, there are three strategies for utilizing gaussian smooth + # 1. with increasing sDev on the same input volume + # 2. with the same sDev iteratively + # Combine 1 & 2 + # Need to do experiment to see which one is better + + volume_conv_list = [None] * (smooth_loops + 1) + volume_conv_list[0] = vol.full_matrix() + + if smooth_by == "PyTorch iterative Gaussian": + volume_conv_list[0], _ = numpy2tensor(volume_conv_list[0], device) + volume_conv_list[0] = linear_norm_tensor(volume_conv_list[0]) + volume_conv_list = conv_volume(volume_conv_list[0], + device, + smooth_loops, + ast.literal_eval(smooth_kernel_sizes), + negative_space_value=negative_space_value, + kernel_type="Gaussian", + mode=Gaussian_mode) + volume_conv_list = [v.squeeze().detach().cpu().numpy() for v in volume_conv_list] + elif smooth_by == "ChimeraX incremental Gaussian": + for conv_idx in range(1, smooth_loops + 1): + vol_gaussian = run(session, f"volume gaussian #{vol.id[0]} sDev {conv_idx}") + + vol_device, _ = numpy2tensor(vol_gaussian.full_matrix(), device) + vol_device = linear_norm_tensor(vol_device) + + eligible_volume_tensor = vol_device > 0.0 + vol_device[~eligible_volume_tensor] = negative_space_value + + volume_conv_list[conv_idx] = vol_device.squeeze().detach().cpu().numpy() + + vol_gaussian.delete() + elif smooth_by == "ChimeraX iterative Gaussian": + kernel_sizes = ast.literal_eval(smooth_kernel_sizes) + vol_current = vol + for conv_idx in range(1, smooth_loops + 1): + vol_gaussian = run(session, f"volume gaussian #{vol_current.id[0]} sDev {kernel_sizes[conv_idx - 1]}") + + if conv_idx > 1: + vol_current.delete() + + vol_device, _ = numpy2tensor(vol_gaussian.full_matrix(), device) + vol_device = linear_norm_tensor(vol_device) + + eligible_volume_tensor = vol_device > 0.0 + vol_device[~eligible_volume_tensor] = negative_space_value + + volume_conv_list[conv_idx] = vol_device.squeeze().detach().cpu().numpy() + + vol_current = vol_gaussian + + vol_current.delete() + + + return volume_conv_list \ No newline at end of file diff --git a/src/parse_log.py b/src/parse_log.py index 855b025..8fe40b6 100644 --- a/src/parse_log.py +++ b/src/parse_log.py @@ -49,7 +49,7 @@ def animate_MQS(e_sqd_log, mol_folder, 
+
+# ==========================================================================
+# Functions intended only for internal use by bundle
+# ==========================================================================
+
+
+def _create_volume_conv_list(session, vol, smooth_by, smooth_loops, smooth_kernel_sizes, Gaussian_mode, device,
+                             negative_space_value=-0.5):
+    # From here on, there are three strategies for utilizing Gaussian smoothing:
+    # 1. with increasing sDev on the same input volume
+    # 2. with the same sDev iteratively
+    # 3. combine 1 & 2
+    # Need to do experiments to see which one is better
+
+    volume_conv_list = [None] * (smooth_loops + 1)
+    volume_conv_list[0] = vol.full_matrix()
+
+    if smooth_by == "PyTorch iterative Gaussian":
+        volume_conv_list[0], _ = numpy2tensor(volume_conv_list[0], device)
+        volume_conv_list[0] = linear_norm_tensor(volume_conv_list[0])
+        volume_conv_list = conv_volume(volume_conv_list[0],
+                                       device,
+                                       smooth_loops,
+                                       ast.literal_eval(smooth_kernel_sizes),
+                                       negative_space_value=negative_space_value,
+                                       kernel_type="Gaussian",
+                                       mode=Gaussian_mode)
+        volume_conv_list = [v.squeeze().detach().cpu().numpy() for v in volume_conv_list]
+    elif smooth_by == "ChimeraX incremental Gaussian":
+        for conv_idx in range(1, smooth_loops + 1):
+            vol_gaussian = run(session, f"volume gaussian #{vol.id[0]} sDev {conv_idx}")
+
+            vol_device, _ = numpy2tensor(vol_gaussian.full_matrix(), device)
+            vol_device = linear_norm_tensor(vol_device)
+
+            eligible_volume_tensor = vol_device > 0.0
+            vol_device[~eligible_volume_tensor] = negative_space_value
+
+            volume_conv_list[conv_idx] = vol_device.squeeze().detach().cpu().numpy()
+
+            vol_gaussian.delete()
+    elif smooth_by == "ChimeraX iterative Gaussian":
+        kernel_sizes = ast.literal_eval(smooth_kernel_sizes)
+        vol_current = vol
+        for conv_idx in range(1, smooth_loops + 1):
+            vol_gaussian = run(session, f"volume gaussian #{vol_current.id[0]} sDev {kernel_sizes[conv_idx - 1]}")
+
+            if conv_idx > 1:
+                vol_current.delete()
+
+            vol_device, _ = numpy2tensor(vol_gaussian.full_matrix(), device)
+            vol_device = linear_norm_tensor(vol_device)
+
+            eligible_volume_tensor = vol_device > 0.0
+            vol_device[~eligible_volume_tensor] = negative_space_value
+
+            volume_conv_list[conv_idx] = vol_device.squeeze().detach().cpu().numpy()
+
+            vol_current = vol_gaussian
+
+        vol_current.delete()
+
+    return volume_conv_list
\ No newline at end of file
diff --git a/src/parse_log.py b/src/parse_log.py
index 855b025..8fe40b6 100644
--- a/src/parse_log.py
+++ b/src/parse_log.py
@@ -49,7 +49,7 @@ def animate_MQS(e_sqd_log, mol_folder, MQS, session, clean_scene=True):
         for structure in structures:
             structure.delete()
 
-    mol_files = os.listdir(mol_folder)
+    mol_files = sorted(os.listdir(mol_folder))
     mol_path = os.path.join(mol_folder, mol_files[MQS[0]])
 
     mol = run(session, f"open {mol_path}")[0]
@@ -69,7 +69,7 @@ def animate_MQS_2(e_sqd_log, mol_folder, MQS, session, clean_scene=True):
         for structure in structures:
             structure.delete()
 
-    mol_files = os.listdir(mol_folder)
+    mol_files = sorted(os.listdir(mol_folder))
    mol_path = os.path.join(mol_folder, mol_files[MQS[0]])
 
     mol = run(session, f"open {mol_path}")[0]
@@ -110,7 +110,7 @@ def look_at_MQS_idx(e_sqd_log, mol_folder, MQS, session, clean_scene=True):
         for structure in structures:
             structure.delete()
 
-    mol_files = os.listdir(mol_folder)
+    mol_files = sorted(os.listdir(mol_folder))
 
     look_at_mol_idx, transformation = get_transformation_at_MQS(e_sqd_log, MQS)
 
@@ -131,7 +131,7 @@ def look_at_cluster(e_sqd_clusters_ordered, mol_folder, cluster_idx, session, cl
         for structure in structures:
             structure.delete()
 
-    mol_files = os.listdir(mol_folder)
+    mol_files = sorted(os.listdir(mol_folder))
 
     # mol_files[idx] pairs with e_sqd_log[idx]
     look_at_mol_idx, transformation = get_transformation_at_idx(e_sqd_clusters_ordered, cluster_idx)
@@ -180,7 +180,7 @@ def simulate_volume(session, vol, mol_paths, mol_idx, transformation, res=4.0):
     mol.atoms.transform(transformation)
 
     from chimerax.map.molmap import molecule_map
-    mol_vol = molecule_map(session, mol.atoms, res, grid_spacing=vol.data_origin_and_step()[1][0])
+    mol_vol = molecule_map(session, mol.atoms, res, grid_spacing=vol.data.step[0])
 
     mol.delete()
     return mol_vol
diff --git a/src/q_shells.py b/src/q_shells.py
new file mode 100644
index 0000000..6f0b9ce
--- /dev/null
+++ b/src/q_shells.py
@@ -0,0 +1,510 @@
+import torch
+import numpy as np
+from datetime import datetime
+from .DiffAtomComp import quaternion_to_matrix_batch, normalize_coordinates_to_map_origin_torch
+
+
+def unit_sphere_vertices(num_vertices):
+    from chimerax.surface.shapes import sphere_geometry2
+    # num_vertices points evenly distributed around a unit sphere centred on (0, 0, 0)
+    sphere_vertices = sphere_geometry2(2 * num_vertices - 4)[0]
+    return sphere_vertices
+
+
+RANDOM_SEED = 1985
+
+
+def generate_q_shells(mol,
+                      points_per_shell=8, max_rad=2.0, step=0.1, num_test_points=128,
+                      clustering_iterations=5, include_h=False, randomize_shell_points=True, random_seed=RANDOM_SEED):
+    '''
+    Implementation of the map-model Q-score as described in Pintilie et al. (2020): https://www.nature.com/articles/s41592-020-0731-1.
+
+    If the model is well-fitted to the map, the Q-score is essentially an atom-by-atom estimate of "resolvability"
+    (i.e. how much the map tells us about the true position of the atom). If the model is *not* well-fitted, then
+    low Q-scores are good pointers to possible problem regions. In practice, of course, usually we have a mixture of
+    both situations.
+
+    This version of the algorithm has a few minor modifications compared to the original implementation aimed at
+    improving overall speed. As a result the scores it returns are not identical to the original, typically differing by
+    +/- 0.04 in individual atom scores and +/- 0.02 in residue averages. This difference can be explained by the different
+    choice of test points, and reflects the underlying sampling uncertainty in the method.
+
+    This implementation works as follows:
+
+    - For each atom, define a set of shells in steps of `step` angstroms out to `max_rad` angstroms.
+    - For each shell, try to find at least `points_per_shell` probe points closer to the test atom than
+      any other atom:
+
+      - For radii smaller than about half a bond length, first try a set of `points_per_shell`
+        points evenly spread around the spherical surface.
+      - For larger radii (or if this quick approach fails to find enough points on smaller radii),
+        start with `num_test_points` evenly spread on the sphere surface, remove points closer to other atoms
+        than the test atom. If more than `points_per_shell` points remain, perform up to `clustering_iterations`
+        rounds of k-means clustering to find `points_per_shell` clusters, and choose the point nearest to the
+        centroid of each cluster. If <= `points_per_shell` points remain, just do the best with what we have.
+        By default, the "seed" centroids for each cluster are chosen pseudo-randomly from the input points.
+        Using the same `random_seed` will give the same result each time; varying `random_seed` over multiple
+        runs can be used to give an idea of the underlying uncertainty in the algorithm. If `randomize_shell_points`
+        is False, the seed centroids will instead be the closest point (in spherical coordinates) to each of
+        `points_per_shell` evenly-spaced points on a unit sphere. While this may intuitively seem preferable,
+        in practice for tightly-packed atoms it leads to oversampling of the "junctions" with other atoms, and
+        undersampling of the unhindered space.
+
+    Returns:
+
+    - a numpy array with q shells coordinates
+    - radii
+    '''
+    global_timer_start = datetime.now()
+
+    from chimerax.geometry import find_close_points, find_closest_points, Places
+    from chimerax.atomic import Residues
+    import numpy as np
+
+    from chimerax.qscore import _kmeans
+
+    pps_vertices = unit_sphere_vertices(points_per_shell)
+    ref_sphere_vertices = unit_sphere_vertices(num_test_points)
+    ref_sphere_vertices_large = unit_sphere_vertices(num_test_points * 4)
+    ref_sphere_vertices_huge = unit_sphere_vertices(num_test_points * 20)
+
+    radii = np.arange(0, max_rad + step / 2, step)
+
+    query_atoms = mol.atoms
+
+    if not include_h:
+        query_atoms = query_atoms[query_atoms.element_names != 'H']
+
+    query_coords = query_atoms.scene_coords
+
+    query_atoms_center = []
+    query_atoms_points = []
+    not_full_shells = 0
+
+    q_scores = []
+    for i, a in enumerate(query_atoms):
+
+        not_full_flag = False
+
+        a_coord = a.scene_coord
+        _, nearby_i = find_close_points([a_coord], query_coords, max_rad * 3)
+        nearby_a = query_atoms[nearby_i]
+        ai = nearby_a.index(a)
+        nearby_coords = nearby_a.scene_coords
+        shell_rad = step
+        local_d_vals = {}
+
+        shell_points = []
+
+        j = 1
+        while shell_rad < max_rad + step / 2:
+            local_pps = (pps_vertices * shell_rad) + a_coord
+            if shell_rad < 0.7:  # about half a C-C bond length
+                # Try the quick way first (should succeed for almost all cases unless geometry is seriously wonky)
+                i1, i2, near1 = find_closest_points(local_pps, nearby_coords, shell_rad * 1.5)
+                closest = near1
+                candidates = i1[closest == ai]
+                if len(candidates) == points_per_shell:
+                    shell_rad += step
+                    j += 1
+
+                    shell_points.append(local_pps)
+
+                    continue
+
+            local_sphere = (ref_sphere_vertices * shell_rad) + a_coord
+            i1, i2, near1 = find_closest_points(local_sphere, nearby_coords, shell_rad * 1.5)
+            closest = near1
+            candidates = i1[closest == ai]
+
+            if len(candidates) < points_per_shell:
+
+                local_sphere = (ref_sphere_vertices_large * shell_rad) + a_coord
+                i1, i2, near1 = find_closest_points(local_sphere, nearby_coords, shell_rad * 1.5)
+                closest = near1
+                candidates = i1[closest == ai]
+
+                if len(candidates) < points_per_shell:
+                    local_sphere = (ref_sphere_vertices_huge * shell_rad) + a_coord
+                    i1, i2, near1 = find_closest_points(local_sphere, nearby_coords, shell_rad * 1.5)
+                    closest = near1
+                    candidates = i1[closest == ai]
+
+                    if len(candidates) < points_per_shell:
+                        not_full_shells += 1
+                        not_full_flag = True
+                    else:
+                        points = local_sphere[candidates]
+                        if not randomize_shell_points:
+                            labels, closest = _kmeans.spherical_k_means_defined(points, a_coord, points_per_shell,
+                                                                                local_pps, clustering_iterations)
+                        else:
+                            labels, closest = _kmeans.spherical_k_means_random(points, a_coord, points_per_shell,
+                                                                               clustering_iterations, random_seed + j)
+                        points = points[closest]
+                else:
+                    points = local_sphere[candidates]
+                    if not randomize_shell_points:
+                        labels, closest = _kmeans.spherical_k_means_defined(points, a_coord, points_per_shell,
+                                                                            local_pps, clustering_iterations)
+                    else:
+                        labels, closest = _kmeans.spherical_k_means_random(points, a_coord, points_per_shell,
+                                                                           clustering_iterations, random_seed + j)
+                    points = points[closest]
+            else:
+                points = local_sphere[candidates]
+                if not randomize_shell_points:
+                    labels, closest = _kmeans.spherical_k_means_defined(points, a_coord, points_per_shell,
+                                                                        local_pps, clustering_iterations)
+                else:
+                    labels, closest = _kmeans.spherical_k_means_random(points, a_coord, points_per_shell,
+                                                                       clustering_iterations, random_seed + j)
+                points = points[closest]
+
+            shell_rad += step
+            j += 1
+
+            shell_points.append(points)
+
+        if not_full_flag:
+            continue
+
+        query_atoms_center.append(a_coord)
+        query_atoms_points.append(shell_points)
+
+    if not_full_shells:
+        print(f"Not_full_shells: {not_full_shells}")
+
+    query_atoms_points_array = np.stack(
+        [np.concatenate(atom_shell_points, axis=0) for atom_shell_points in query_atoms_points])
+    query_atoms_center_np = np.stack(query_atoms_center)
+    query_atoms_center_repeated = np.repeat(query_atoms_center_np[:, np.newaxis, :], points_per_shell, axis=1)
+    q_shell_coords = np.concatenate([query_atoms_center_repeated, query_atoms_points_array], axis=1)
+
+    print(f"Generate q shells timer: {datetime.now() - global_timer_start}")
+
+    return q_shell_coords, radii
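A sketch of how these pieces fit together when re-scoring fits for a single-structure run (so all cluster rows reference `mol_idx` 0). The open `mol` and `volume` models and the `fit_res_clusters`/`fit_res_all` arrays from a previous DiffFit run are assumed:

```python
# Sketch only: `mol` is an open structure, `volume` an open map, and
# `fit_res_clusters`/`fit_res_all` come from cluster_and_sort_sqd_fast.
import torch

q_shell_coords, radii = generate_q_shells(mol)  # (n_atoms, 168, 3) with the defaults
q_shells_torch = torch.tensor(q_shell_coords, device="cuda").float()

q_scores = q_scores_for_clusters([q_shells_torch], [radii], volume,
                                 fit_res_clusters, fit_res_all)
print(q_scores.shape)  # one mean Q-score per cluster representative
```

The 168 in the shape comes from the defaults: 8 copies of the atom centre plus 20 shells of 8 probe points each, which also explains the hard-coded reshape in `q_scores_for_clusters` below.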
target = torch.tensor(target_no_negative, device=device).float() + target_dim = target.shape + + import operator + target_size = np.array(list(map(operator.mul, target_dim, target_steps))) + + target_size_x_y_z = [target_size[2], target_size[1], target_size[0]] + target_size_x_y_z_tensor = torch.tensor(target_size_x_y_z, device=device).float() + target_origin_tensor = torch.tensor(target_origin, device=device).float() + + target = target.unsqueeze(0).unsqueeze(0) + + + min_d, max_d = min_max_d(volume) + a = max_d - min_d + b = min_d + + if save_log: + with open(log_path, "a") as log_file: + log_file.write(f"Starting to calculate Q-scores for {len(fit_res_clusters)} fits: {datetime.now()}\n" + f"-------\n") + + q_scores = [] + for cluster_idx in range(len(fit_res_clusters)): + mol_idx = int(fit_res_clusters[cluster_idx, 0]) + + # q ref + num_shells = len(q_shell_radii_np_list[mol_idx]) + q_reference_gaussian = a * np.exp(-0.5 * (q_shell_radii_np_list[mol_idx] / ref_sigma) ** 2) + b + q_ref = np.concatenate([[q_reference_gaussian[0]] * points_per_shell, + *[[q_reference_gaussian[j]] * points_per_shell for j in range(num_shells - 1)]]) + q_ref = torch.tensor(q_ref, device=device, dtype=torch.float32) + q_ref -= q_ref.mean() + + # q shells q scores + q_shell_coords = q_shell_coords_torch_list[mol_idx] + transformed_coords = torch.matmul(q_shell_coords, quaternions_matrices[:, cluster_idx:cluster_idx + 1, :, :]) + + transformed_coords += shifts[cluster_idx, :] + + q_shell_coords_normalized_to_target = normalize_coordinates_to_map_origin_torch(transformed_coords, + target_size_x_y_z_tensor, + target_origin_tensor) + + q_shell_coords_normalized_to_target = q_shell_coords_normalized_to_target.reshape([1, 1, -1, 168, 3]) + + q_measure = torch.nn.functional.grid_sample(target, q_shell_coords_normalized_to_target, 'bilinear', 'border', + align_corners=True) + + q_measure.squeeze_() + + q_measure -= q_measure.mean(dim=-1, keepdim=True) + + inner_product = torch.matmul(q_measure, q_ref) + q_measure_l2 = torch.norm(q_measure, p=2, dim=-1) + q_ref_l2 = torch.norm(q_ref, p=2, dim=-1) + q_score_torch = inner_product / (q_measure_l2 * q_ref_l2) + + q_scores.append(q_score_torch[~torch.isnan(q_score_torch)].mean()) + + if save_log and (cluster_idx + 1) % 1000 == 0: + with open(log_path, "a") as log_file: + log_file.write(f"Q-scoring {cluster_idx + 1} fits: {datetime.now()}\n") + + q_scores_tensor = torch.stack(q_scores) + # top_10_values, top_10_indices = torch.topk(q_scores_tensor, k=10) + # + # # Display the results + # print("Top 10 values:", top_10_values) + # print("Indices of top 10 values + 1:", top_10_indices + 1) + # + # return top_10_values, top_10_indices + + return q_scores_tensor.cpu().numpy() + + +from concurrent.futures import ThreadPoolExecutor +def process_atom(atom_idx, query_atoms, query_coords, + pps_vertices, + ref_sphere_vertices, ref_sphere_vertices_large, ref_sphere_vertices_huge, + step, max_rad, points_per_shell, + clustering_iterations, + randomize_shell_points, random_seed): + from chimerax.geometry import find_close_points, find_closest_points + from chimerax.qscore import _kmeans + + not_full_flag = False + + a = query_atoms[atom_idx] + a_coord = query_coords[atom_idx] + _, nearby_i = find_close_points([a_coord], query_coords, max_rad * 3) + nearby_a = query_atoms[nearby_i] + ai = nearby_a.index(a) + nearby_coords = nearby_a.scene_coords + shell_rad = step + local_d_vals = {} + + shell_points = [] + + j = 1 + while shell_rad < max_rad + step / 2: + local_pps = (pps_vertices * 
shell_rad) + a_coord + if shell_rad < 0.7: # about half a C-C bond length + # Try the quick way first (should succeed for almost all cases unless geometry is seriously wonky) + i1, i2, near1 = find_closest_points(local_pps, nearby_coords, shell_rad * 1.5) + closest = near1 + candidates = i1[closest == ai] + if len(candidates) == points_per_shell: + shell_rad += step + j += 1 + + shell_points.append(local_pps) + + continue + + local_sphere = (ref_sphere_vertices * shell_rad) + a_coord + i1, i2, near1 = find_closest_points(local_sphere, nearby_coords, shell_rad * 1.5) + closest = near1 + candidates = i1[closest == ai] + + if len(candidates) < points_per_shell: + + local_sphere = (ref_sphere_vertices_large * shell_rad) + a_coord + i1, i2, near1 = find_closest_points(local_sphere, nearby_coords, shell_rad * 1.5) + closest = near1 + candidates = i1[closest == ai] + + if len(candidates) < points_per_shell: + local_sphere = (ref_sphere_vertices_huge * shell_rad) + a_coord + i1, i2, near1 = find_closest_points(local_sphere, nearby_coords, shell_rad * 1.5) + closest = near1 + candidates = i1[closest == ai] + + if len(candidates) < points_per_shell: + not_full_flag = True + + else: + points = local_sphere[candidates] + if not randomize_shell_points: + labels, closest = _kmeans.spherical_k_means_defined(points, a_coord, points_per_shell, + local_pps, clustering_iterations) + else: + labels, closest = _kmeans.spherical_k_means_random(points, a_coord, points_per_shell, + clustering_iterations, random_seed + j) + + points = points[closest] + + else: + points = local_sphere[candidates] + if not randomize_shell_points: + labels, closest = _kmeans.spherical_k_means_defined(points, a_coord, points_per_shell, + local_pps, clustering_iterations) + else: + labels, closest = _kmeans.spherical_k_means_random(points, a_coord, points_per_shell, + clustering_iterations, random_seed + j) + + points = points[closest] + + else: + points = local_sphere[candidates] + if not randomize_shell_points: + labels, closest = _kmeans.spherical_k_means_defined(points, a_coord, points_per_shell, + local_pps, clustering_iterations) + else: + labels, closest = _kmeans.spherical_k_means_random(points, a_coord, points_per_shell, + clustering_iterations, random_seed + j) + + points = points[closest] + + shell_rad += step + j += 1 + + shell_points.append(points) + + return None if not_full_flag else (a_coord, shell_points) + + +def generate_q_shells_parallel(mol, + points_per_shell=8, max_rad=2.0, step=0.1, + num_test_points=128, clustering_iterations=5, + include_h=False, randomize_shell_points=True, random_seed=RANDOM_SEED): + global_timer_start = datetime.now() + + query_atoms = mol.atoms + if not include_h: + query_atoms = query_atoms[query_atoms.element_names != 'H'] + query_coords = query_atoms.scene_coords + + pps_vertices = unit_sphere_vertices(points_per_shell) + ref_sphere_vertices = unit_sphere_vertices(num_test_points) + ref_sphere_vertices_large = unit_sphere_vertices(num_test_points * 4) + ref_sphere_vertices_huge = unit_sphere_vertices(num_test_points * 20) + + # Multithreading + results = [] + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit( + process_atom, atom_idx, query_atoms, query_coords, pps_vertices, ref_sphere_vertices, + ref_sphere_vertices_large, ref_sphere_vertices_huge, step, max_rad, points_per_shell, + clustering_iterations, randomize_shell_points, random_seed + ) + for atom_idx in range(len(query_atoms)) + ] + for future in futures: + result = future.result() + if result is not None: 
+ results.append(result) + + query_atoms_center, query_atoms_points = zip(*results) + query_atoms_points_array = np.stack( + [np.concatenate(atom_shell_points, axis=0) for atom_shell_points in query_atoms_points]) + query_atoms_center_np = np.stack(query_atoms_center) + query_atoms_center_repeated = np.repeat(query_atoms_center_np[:, np.newaxis, :], points_per_shell, axis=1) + q_shell_coords = np.concatenate([query_atoms_center_repeated, query_atoms_points_array], axis=1) + + print(f"Generate q shells timer: {datetime.now() - global_timer_start}") + + return q_shell_coords, np.arange(0, max_rad + step / 2, step) + + +def generate_q_shells_simple(mol, + points_per_shell=8, max_rad=2.0, step=0.1, num_test_points=128, + clustering_iterations=5, include_h=False, randomize_shell_points=True, random_seed=RANDOM_SEED): + ''' + Returns: + + - a numpy array with q shells coordinates + - radii + ''' + from datetime import datetime + global_timer_start = datetime.now() + + from chimerax.geometry import find_close_points, find_closest_points, Places + import numpy as np + + + pps_vertices = unit_sphere_vertices(points_per_shell) + + query_atoms = mol.atoms + + if not include_h: + query_atoms = query_atoms[query_atoms.element_names != 'H'] + + query_coords = query_atoms.scene_coords + + query_atoms_center = [] + query_atoms_points = [] + + for i, a in enumerate(query_atoms): + + a_coord = a.scene_coord + shell_rad = step + shell_points = [] + + j = 1 + while shell_rad < max_rad + step / 2: + local_pps = (pps_vertices * shell_rad) + a_coord + shell_points.append(local_pps) + shell_rad += step + + query_atoms_center.append(a_coord) + query_atoms_points.append(shell_points) + + query_atoms_points_array = np.stack( + [np.concatenate(atom_shell_points, axis=0) for atom_shell_points in query_atoms_points]) + query_atoms_center_np = np.stack(query_atoms_center) + query_atoms_center_repeated = np.repeat(query_atoms_center_np[:, np.newaxis, :], points_per_shell, axis=1) + q_shell_coords = np.concatenate([query_atoms_center_repeated, query_atoms_points_array], axis=1) + + print(f"Generate q shells timer: {datetime.now() - global_timer_start}") + + return q_shell_coords, np.arange(0, max_rad + step / 2, step) \ No newline at end of file diff --git a/src/run.py b/src/run.py index 7654a9e..74d2ead 100644 --- a/src/run.py +++ b/src/run.py @@ -14,7 +14,7 @@ os.makedirs(out_base_dir, exist_ok=True) # Loop through all files in the segmented_maps directory -for map_file in os.listdir(segmented_maps_dir): +for map_file in sorted(os.listdir(segmented_maps_dir)): if map_file.endswith(".mrc"): # Construct the full path to the map file target_vol = os.path.join(segmented_maps_dir, map_file) diff --git a/src/split_chains.py b/src/split_chains.py index 5232810..53fee9c 100644 --- a/src/split_chains.py +++ b/src/split_chains.py @@ -8,6 +8,7 @@ # Import necessary modules import os, sys +from pathlib import Path from chimerax.core.commands import run from datetime import datetime import numpy as np @@ -25,7 +26,7 @@ structure = run(session, f'open {input_model}')[0] # Get the base name of the input structure for naming output files -structure_basename = os.path.basename(input_model).split('.')[0] +structure_basename = Path(input_model).stem # Iterate over each chain in the structure chain_id_name_list = [] diff --git a/src/tablemodel.py b/src/tablemodel.py index dea14fb..dd16f19 100644 --- a/src/tablemodel.py +++ b/src/tablemodel.py @@ -3,17 +3,19 @@ class TableModel(QAbstractTableModel): """A model to interface a Qt view with 
pandas DataFrame."""

-    def __init__(self, sqd_cluster_data, sqd_data, mol_paths, parent=None):
+    def __init__(self, sqd_cluster_data, sqd_data, mol_paths, mol_num_atoms, parent=None):
         QAbstractTableModel.__init__(self, parent)
         self._sqd_data = sqd_data
         self._sqd_cluster_data = sqd_cluster_data
 
         import os
         self._mol_names = [os.path.splitext(os.path.basename(path))[0] for path in mol_paths]
+        self._mol_num_atoms = mol_num_atoms
 
-
-        self._header = ["Id", "Mol name", "Hits",
-                        "Density", "Overlap", "Correlation", "Cam", "Inside"]
+        self._header = ["Id", "Mol name", "Q-score", "Hits", "# Atoms",
+                        "Density (normalized)", "Overlap", "Correlation", "Cam", "Inside",
+                        "Avg Density (in)", "Avg Density (all)",
+                        "DF CID"]
 
         # mapping of columns (from view to data)
         # self._mapping = [-1, -1, 10, 11, 12, 13]
@@ -59,11 +61,18 @@ def data(self, index: QModelIndex, role=Qt.ItemDataRole):
         elif column == 1:
             return str(f"{mol_idx}-{self._mol_names[mol_idx]}")
         elif column == 2:
-            return int(self._sqd_cluster_data[index.row(), 3])
-        elif 3 <= column <= 7:
+            try:
+                return round(float(self._sqd_cluster_data[index.row(), 3]), 4)  # Q-score, 4 decimals
+            except (TypeError, ValueError):
+                return None
+        elif column == 3:
+            return int(self._sqd_cluster_data[index.row(), 4])
+        elif column == 4:
+            return int(self._mol_num_atoms[mol_idx])
+        elif 5 <= column <= 12:
             record_row = self._sqd_data[mol_idx, record_idx, iter_idx]
-            return float(round(float(record_row[index.column() + 4]) * 10000)) / 10000.0  # for 4 decimals
-            # return float(record_row[index.column() + 4])
+            return round(float(record_row[column + 2]), 4)  # 4 decimals
+            # return float(record_row[index.column() + 2])
 
         return None
diff --git a/src/tool.py b/src/tool.py
index 05d982f..219d1ed 100644
--- a/src/tool.py
+++ b/src/tool.py
@@ -40,12 +40,66 @@
 import sys
 import numpy as np
 import os
+from pathlib import Path
 import torch
 import psutil
 import platform
 import ast
 from scipy.interpolate import interp1d
-
+
+
+def calculate_candidate_indices(q_scores, num_test_fits=20):
+    """Return the indices of the top fits that clearly stand out, based on the
+    largest gap between consecutively sorted Q-scores."""
+    try:
+        # Indices of the top num_test_fits values (unordered)
+        top_fits_indices = np.argpartition(q_scores, -num_test_fits)[-num_test_fits:]
+        top_fits_values = q_scores[top_fits_indices]
+
+        # Sort the top values and their indices in descending order
+        sorted_indices_desc = np.argsort(-top_fits_values)
+
+        # Calculate consecutive differences and negate them
+        consecutive_differences = -np.diff(top_fits_values[sorted_indices_desc])
+        std = np.std(consecutive_differences)
+        if std == 0:
+            return np.array([])  # all gaps are equal, no fit stands out
+        largest_gap_index = np.argmax(consecutive_differences)
+        largest_gap_ratio = consecutive_differences[largest_gap_index] / std
+
+        if largest_gap_ratio > 1:
+            return np.array(range(largest_gap_index + 1))
+        else:
+            return np.array([])
+    except Exception as e:
+        print(f"Warning: could not determine candidate indices: {e}")
+        return np.array([])
+
+
+def generate_q_shells_wrapper(q_shell_generator, mol_path, q_shells_ext, session):
+    print(f"Q shell generator: {q_shell_generator}: {mol_path}")
+
+    mol_basename = Path(mol_path).stem
+    mol_folder = os.path.dirname(mol_path)
+
+    mol = run(session, f'open {mol_path}')[0]
+
+    # Center the molecule at the origin, so that the stored shells are
+    # independent of the model's placement in the scene
+    from chimerax.geometry import Place
+    mol_center = mol.atoms.coords.mean(axis=0)
+    transform = Place(origin=-mol_center)
+    mol.atoms.transform(transform)
+    mol.position = Place()
+
+    q_shell_coords, radii = q_shell_generator(mol)
+
+    q_shells_filepath = os.path.join(mol_folder, f"{mol_basename}.{q_shells_ext}")
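+    # A minimal sketch of reading the archive back (the same pattern is used
+    # in show_results when cached shells are found on disk):
+    #     data = np.load(q_shells_filepath)
+    #     q_shell_coords, radii = data['q_shell_coords'], data['radii']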
+    np.savez_compressed(q_shells_filepath,
+                        q_shell_coords=q_shell_coords,
+                        radii=radii)  # note: q_scores_points_per_shell is not stored and must match at load time
+
+    run(session, f"close #{mol.id[0]}")
+
+    return q_shell_coords, radii
+
+
 def create_row(parent_layout, left=0, top=0, right=0, bottom=0, spacing=5):
     row_frame = QFrame()
@@ -95,17 +149,15 @@ def interp_backbone_for_mol(mol):
 class DiffFitSettings:
     def __init__(self):
         # viewing
-        self.view_output_directory: str = "D:\\GIT\\DiffFit\\dev_data\\output"
-        self.view_target_vol_path: str = "D:\\GIT\\DiffFit\\dev_data\\input\\domain_fit_demo_3domains\\density2.mrc"
-        self.view_structures_directory: str = "D:\\GIT\\DiffFit\dev_data\input\domain_fit_demo_3domains\subunits_cif"
+        self.view_output_directory: str = "D:/GIT/DiffFit/dev_data/output"
+        self.view_structures_directory: str = "D:/GIT/DiffFit/dev_data/input/domain_fit_demo_3domains/subunits_cif"
 
         # computing
-        self.input_directory: str = "D:\\GIT\\DiffFit\\dev_data\\input\\domain_fit_demo_3domains"
-        self.target_vol_path: str = "D:\\GIT\\DiffFit\\dev_data\\input\\domain_fit_demo_3domains\\density2.mrc"
-        self.structures_directory: str = "D:\\GIT\\DiffFit\dev_data\input\domain_fit_demo_3domains\subunits_cif"
-        self.structures_sim_map_dir: str = "D:\\GIT\\DiffFit\dev_data\input\domain_fit_demo_3domains\subunits_mrc"
+        self.input_directory: str = "D:/GIT/DiffFit/dev_data/input/domain_fit_demo_3domains"
+        self.target_vol_path: str = "D:/GIT/DiffFit/dev_data/input/domain_fit_demo_3domains/density2.mrc"
+        self.structures_directory: str = "D:/GIT/DiffFit/dev_data/input/domain_fit_demo_3domains/subunits_cif"
 
-        self.output_directory: str = "D:\\GIT\\DiffFit\\dev_data\\output"
+        self.output_directory: str = "D:/GIT/DiffFit/dev_data/output"
 
         self.target_surface_threshold: float = 2.0
         self.min_cluster_size: float = 100
@@ -125,6 +177,28 @@ def __init__(self):
         self.clustering_in_contour_threshold: float = 0.2
         self.clustering_correlation_threshold: float = 0.5
+        self.df_cid_threshold = 0.15
+
+
+class DiffFitTableView(QTableView):
+    def __init__(self, parent=None):
+        super().__init__(parent)
+        self.parent = parent
+
+        self.up_down_key_callback = None  # callback invoked when Up/Down changes the selection
+
+    def keyPressEvent(self, event):
+        super().keyPressEvent(event)
+        if event.key() == Qt.Key.Key_Up or event.key() == Qt.Key.Key_Down:
+            # Trigger the callback function if registered
+            if self.up_down_key_callback:
+                current_index = self.currentIndex()
+                self.up_down_key_callback(current_index)
+
+    def setUpDownKeyCallback(self, callback):
+        """Register a callback to be called when Up/Down changes the selected row."""
+        self.up_down_key_callback = callback
+
 
 class DiffFitTool(ToolInstance):
@@ -178,7 +252,7 @@ def __init__(self, session, tool_name):
         self.interactive_fit_result_ready = False
         self.fit_result = None
-        self.mol_centers = None
+        self.mol_num_atoms = None
 
         self.cluster_color_map = {}
@@ -189,6 +263,7 @@ def __init__(self, session, tool_name):
         self.spheres = None
         self.proxyModel = None
+        self.mol = None
 
     def _build_ui(self):
@@ -262,7 +337,6 @@ def load_settings(self):
         #compute
         self.target_vol_path.setText(self.settings.target_vol_path)
         self.structures_dir.setText(self.settings.structures_directory)
-        self.structures_sim_map_dir.setText(self.settings.structures_sim_map_dir)
         self.out_dir.setText(self.settings.output_directory)
         self.target_surface_threshold.setValue(self.settings.target_surface_threshold)
         self.min_cluster_size.setValue(self.settings.min_cluster_size)
@@ -276,10 +350,10 @@ def load_settings(self):
self.conv_weights.setText("[{0}]".format(','.join(map(str, self.settings.conv_weights)))) # view - self.target_vol.setText(self.settings.view_target_vol_path) - self.dataset_folder.setText(self.settings.view_output_directory) + self.dataset_folder.setText(self.settings.view_output_directory) # clustering + self.df_cid_threshold.setValue(self.settings.df_cid_threshold) self.clustering_in_contour_threshold.setValue(self.settings.clustering_in_contour_threshold) self.clustering_correlation_threshold.setValue(self.settings.clustering_correlation_threshold) self.clustering_angle_tolerance.setValue(self.settings.clustering_angle_tolerance) @@ -296,7 +370,6 @@ def store_settings(self): #compute self.settings.target_vol_path = self.target_vol_path.text() self.settings.structures_directory = self.structures_dir.text() - self.settings.structures_sim_map_dir = self.structures_sim_map_dir.text() self.settings.output_directory = self.out_dir.text() self.settings.target_surface_threshold = self.target_surface_threshold.value() self.settings.min_cluster_size = self.min_cluster_size.value() @@ -315,16 +388,14 @@ def store_settings(self): #view self.settings.view_output_directory = self.dataset_folder.text() - self.settings.view_target_vol_path = self.target_vol.text() - + # clustering + self.settings.df_cid_threshold = self.df_cid_threshold.value() self.settings.clustering_in_contour_threshold = self.clustering_in_contour_threshold.value() self.settings.clustering_correlation_threshold = self.clustering_correlation_threshold.value() self.settings.clustering_angle_tolerance = self.clustering_angle_tolerance.value() self.settings.clustering_shift_tolerance = self.clustering_shift_tolerance.value() - #print(self.settings) - #print(self.settings.view_target_vol_path) def build_single_fit_ui(self, layout): row = QHBoxLayout() @@ -537,18 +608,7 @@ def build_compute_ui(self, layout): layout.addWidget(self.structures_dir, row, 1) layout.addWidget(structures_dir_select, row, 2) row = row + 1 - - structures_sim_map_dir_label = QLabel() - structures_sim_map_dir_label.setText("Structures Sim-map Folder:") - self.structures_sim_map_dir = QLineEdit() - self.structures_sim_map_dir.textChanged.connect(lambda: self.store_settings()) - structures_sim_map_dir_select = QPushButton("Select") - structures_sim_map_dir_select.clicked.connect(lambda: self.select_clicked("Structures Sim-map Folder", self.structures_sim_map_dir)) - layout.addWidget(structures_sim_map_dir_label, row, 0) - layout.addWidget(self.structures_sim_map_dir, row, 1) - layout.addWidget(structures_sim_map_dir_select, row, 2) - row = row + 1 - + out_dir_label = QLabel() out_dir_label.setText("Output Folder:") self.out_dir = QLineEdit() @@ -804,23 +864,11 @@ def build_utilities_ui(self, layout): row = row + 1 - doc_label = QLabel("Simulate a map for each structure in the folder") + doc_label = QLabel("Simulate a map for each structure in a folder") doc_label.setWordWrap(True) layout.addWidget(doc_label, row, 0, 1, 3) row = row + 1 - sim_out_dir_label = QLabel() - sim_out_dir_label.setText("Output Folder:") - self.sim_out_dir = QLineEdit() - self.sim_out_dir.setText("sim_out") - sim_out_dir_select = QPushButton("Select") - sim_out_dir_select.clicked.connect( - lambda: self.select_clicked("Output folder for the simulated maps", self.sim_out_dir)) - layout.addWidget(sim_out_dir_label, row, 0) - layout.addWidget(self.sim_out_dir, row, 1) - layout.addWidget(sim_out_dir_select, row, 2) - row = row + 1 - sim_dir_label = QLabel() sim_dir_label.setText("Structures 
Folder:")
+        self.sim_dir = QLineEdit()
@@ -850,6 +898,37 @@ def build_utilities_ui(self, layout):
         layout.addWidget(button, row, 2)
         row = row + 1
 
+        doc_label = QLabel("Generate q shells for each structure in a folder")
+        doc_label.setWordWrap(True)
+        layout.addWidget(doc_label, row, 0, 1, 3)
+        row = row + 1
+
+        q_shells_label = QLabel()
+        q_shells_label.setText("Structures Folder:")
+        self.q_shells_dir = QLineEdit()
+        self.q_shells_dir.setText("split_out")
+        q_shells_dir_select = QPushButton("Select")
+        q_shells_dir_select.clicked.connect(
+            lambda: self.select_clicked("Folder containing the structures", self.q_shells_dir))
+        layout.addWidget(q_shells_label, row, 0)
+        layout.addWidget(self.q_shells_dir, row, 1)
+        layout.addWidget(q_shells_dir_select, row, 2)
+        row = row + 1
+
+        q_shells_mode_label = QLabel()
+        q_shells_mode_label.setText("Mode:")
+        button = QPushButton()
+        button.setText("Full (non-overlapping)")
+        button.clicked.connect(lambda: self.q_shell_button_clicked("Full"))
+        layout.addWidget(q_shells_mode_label, row, 0)
+        layout.addWidget(button, row, 1)
+
+        button = QPushButton()
+        button.setText("Simple")
+        button.clicked.connect(lambda: self.q_shell_button_clicked("Simple"))
+        layout.addWidget(button, row, 2)
+        row = row + 1
+
         vertical_spacer = QSpacerItem(1, 1, QSizePolicy.Minimum, QSizePolicy.Expanding)
         layout.addItem(vertical_spacer, row + 1, 0)
@@ -908,7 +987,8 @@ def build_settings_ui(self, layout):
         layout.addLayout(row)
 
         doc_label = QLabel("If the map's resolution < 5.0, we suggest using \"Gaussian with negative (shrink)\".\n"
-                           "Otherwise, we suggest using \"Gaussian then negative (expand)\".\n")
+                           "Otherwise, we suggest giving \"Gaussian then negative (expand)\" a try. "
+                           "In most cases, though, the influence of this parameter is mild.
\n") doc_label.setWordWrap(True) row.addWidget(doc_label) @@ -926,18 +1006,7 @@ def build_view_ui(self, layout): layout.addWidget(self._view_input_mode, row, 1, 1, 2) row = row + 1 - target_vol_label = QLabel("Target Volume:") - self.target_vol = QLineEdit() - self.target_vol.textChanged.connect(lambda: self.store_settings()) - self.target_vol.setEnabled(False) - self.target_vol_select = QPushButton("Select") - self.target_vol_select.setEnabled(False) - self.target_vol_select.clicked.connect(lambda: self.select_clicked("Target Volume", self.target_vol, False, "MRC Files(*.mrc);;MAP Files(*.map)")) - layout.addWidget(target_vol_label, row, 0) - layout.addWidget(self.target_vol, row, 1) - layout.addWidget(self.target_vol_select, row, 2) - row = row + 1 - + # data folder - where the data is stored dataset_folder_label = QLabel("Result Folder:") self.dataset_folder = QLineEdit() @@ -949,6 +1018,17 @@ def build_view_ui(self, layout): layout.addWidget(self.dataset_folder_select, row, 2) row = row + 1 + df_cid_threshold_label = QLabel() + df_cid_threshold_label.setText("DF CID threshold:") + self.df_cid_threshold = QDoubleSpinBox() + self.df_cid_threshold.setMinimum(-1.0) + self.df_cid_threshold.setMaximum(1.0) + self.df_cid_threshold.setSingleStep(0.01) + self.df_cid_threshold.valueChanged.connect(lambda: self.store_settings()) + layout.addWidget(df_cid_threshold_label, row, 0) + layout.addWidget(self.df_cid_threshold, row, 1, 1, 2) + row = row + 1 + clustering_in_contour_threshold_label = QLabel() clustering_in_contour_threshold_label.setText("In contour threshold:") self.clustering_in_contour_threshold = QDoubleSpinBox() @@ -1010,12 +1090,13 @@ def build_view_ui(self, layout): #layout.addWidget(self.line_edit) # table view of all the results - view = QTableView() + view = DiffFitTableView() view.resize(800, 500) view.horizontalHeader().setStretchLastSection(True) view.setAlternatingRowColors(True) view.setSelectionBehavior(QTableView.SelectRows) - view.clicked.connect(self.table_row_clicked) + view.clicked.connect(self.table_row_clicked) + view.setUpDownKeyCallback(self.table_row_clicked) layout.addWidget(view) self.view = view layout.addWidget(view, row, 0, 1, 3) @@ -1026,7 +1107,28 @@ def build_view_ui(self, layout): stats.setText("Stats: ") self.stats = stats layout.addWidget(stats, row, 0, 1, 3) - row = row + 1 + row = row + 1 + + # Adding "Candidates" field and Save button + candidates_folder_label = QLabel("Candidates Folder:") + self.candidates_folder = QLineEdit() + candidates_folder_select = QPushButton("Select") + candidates_folder_select.clicked.connect(lambda: self.select_clicked("Save candidates to", self.candidates_folder)) + layout.addWidget(candidates_folder_label, row, 0) + layout.addWidget(self.candidates_folder, row, 1) + layout.addWidget(candidates_folder_select, row, 2) + row = row + 1 + + candidates_label = QLabel("Candidates id: ") + self.candidates_id = QLineEdit() + self.candidates_id.setText("") + save_button = QPushButton("Save structures") + save_button.clicked.connect(self.save_candidates) + + layout.addWidget(candidates_label, row, 0) + layout.addWidget(self.candidates_id, row, 1) + layout.addWidget(save_button, row, 2) + row += 1 # button panel simulate_volume_label = QLabel("Resolution:") @@ -1198,14 +1300,10 @@ def _Gaussian_mode_changed(self): def _view_input_mode_changed(self): if self._view_input_mode.currentText() == "interactive": self.fit_input_mode = "interactive" - self.target_vol.setEnabled(False) - self.target_vol_select.setEnabled(False) 
self.dataset_folder.setEnabled(False) self.dataset_folder_select.setEnabled(False) elif self._view_input_mode.currentText() == "disk file": self.fit_input_mode = "disk file" - self.target_vol.setEnabled(False) - self.target_vol_select.setEnabled(False) self.dataset_folder.setEnabled(True) self.dataset_folder_select.setEnabled(True) @@ -1322,18 +1420,14 @@ def select_clicked(self, text, target, save = False, pattern = "dir"): ext = "" if save: - options = QFileDialog.Options() - options |= QFileDialog.DontUseNativeDialog - fileName, ext = QFileDialog.getSaveFileName(target, text, "", pattern, options = options) + fileName, ext = QFileDialog.getSaveFileName(target, text, "", pattern) ext = ext[-4:] ext = ext[:3] else: if pattern == "dir": fileName = QFileDialog.getExistingDirectory(target, text) elif len(pattern) > 0 : - options = QFileDialog.Options() - options |= QFileDialog.DontUseNativeDialog - fileName, ext = QFileDialog.getOpenFileName(target, text, "", pattern, options = options) + fileName, ext = QFileDialog.getOpenFileName(target, text, "", pattern) ext = ext[-4:] ext = ext[:3] @@ -1344,7 +1438,11 @@ def select_clicked(self, text, target, save = False, pattern = "dir"): return fileName, ext - def show_results(self, e_sqd_log, mol_centers, mol_paths, target_vol_path=None, target_surface_threshold=None): + def show_results(self, e_sqd_log, mol_num_atoms, mol_paths, + target_vol_path=None, + target_surface_threshold=None, + save_log=False, + log_path=""): if e_sqd_log is None: return @@ -1360,40 +1458,130 @@ def show_results(self, e_sqd_log, mol_centers, mol_paths, target_vol_path=None, self.vol = run(self.session, "open {0}".format(target_vol_path))[0] run(self.session,f"volume #{self.vol.id[0]} level {target_surface_threshold}") - # TODO: define mol_centers - elif self.fit_input_mode == "interactive": self.vol = self.fit_vol self.vol.display = True + timer_start = datetime.now() + if save_log: + with open(log_path, "a") as log_file: + log_file.write(f"-------\n" + f"DiffFit clustering starts: {timer_start}\n") + N_mol, N_quat, N_shift, N_iter, N_metric = e_sqd_log.shape self.e_sqd_log = e_sqd_log.reshape([N_mol, N_quat * N_shift, N_iter, N_metric]) - self.e_sqd_clusters_ordered = cluster_and_sort_sqd_fast(self.e_sqd_log, mol_centers, + self.e_sqd_clusters_ordered = cluster_and_sort_sqd_fast(self.e_sqd_log, self.settings.clustering_shift_tolerance, self.settings.clustering_angle_tolerance, in_contour_threshold=self.settings.clustering_in_contour_threshold, - correlation_threshold=self.settings.clustering_correlation_threshold) + correlation_threshold=self.settings.clustering_correlation_threshold, + df_cid_threshold=self.settings.df_cid_threshold, + save_log=save_log, + log_path=log_path) + if save_log: + with open(log_path, "a") as log_file: + log_file.write(f"-------\n" + f"DiffFit clustering time elapsed: {datetime.now() - timer_start}\n") if self.e_sqd_clusters_ordered is None: self.session.logger.error("No result under these thresholds. 
Please decrease \"DF CID threshold\", \"In contour threshold\", or \"Correlation threshold\", or rerun the fitting!")
             self.proxyModel = None
             return
 
-        self.model = TableModel(self.e_sqd_clusters_ordered, self.e_sqd_log, mol_paths)
+        # ======= Calculate Q-scores
+
+        timer_start = datetime.now()
+        if save_log:
+            with open(log_path, "a") as log_file:
+                log_file.write(f"-------\n"
+                               f"DiffFit Q-scores calculation starts: {timer_start}\n")
+
+        q_shells_mode = "Full"  # hardcoded here; "Simple" shells can be pre-generated from the Utilities tab
+
+        q_scores_points_per_shell = 8
+        q_scores_max_rad = 2.0
+        q_scores_step = 0.1
+        q_shell_radii = np.arange(0, q_scores_max_rad + q_scores_step / 2, q_scores_step)
+
+        q_shell_coords_torch_list = []
+        q_shell_radii_np_list = []
+
+        from .q_shells import generate_q_shells, generate_q_shells_simple, q_scores_for_clusters
+
+        q_shell_generator = None
+        q_shells_ext = ""
+        if q_shells_mode == "Full":
+            q_shell_generator = generate_q_shells
+            q_shells_ext = "centered_q_shells.full.npz"
+        elif q_shells_mode == "Simple":
+            q_shell_generator = generate_q_shells_simple
+            q_shells_ext = "centered_q_shells.simple.npz"
+
+        for mol_path in mol_paths:
+            mol_basename = Path(mol_path).stem
+            mol_folder = os.path.dirname(mol_path)
+            q_shells_filepath = os.path.join(mol_folder, f"{mol_basename}.{q_shells_ext}")
+
+            if os.path.exists(q_shells_filepath):
+                q_shells_np = np.load(q_shells_filepath)
+                q_shell_coords = q_shells_np['q_shell_coords']
+                q_shell_radii = q_shells_np['radii']  # note: q_scores_points_per_shell is not stored in the archive and must match
+            else:
+                q_shell_coords, q_shell_radii = generate_q_shells_wrapper(q_shell_generator,
+                                                                          mol_path,
+                                                                          q_shells_ext,
+                                                                          self.session)
+
+            q_shell_coords = torch.tensor(q_shell_coords, device=self._device.currentText()).float()
+            q_shell_coords = q_shell_coords.reshape([-1, 3])
+
+            q_shell_coords_torch_list.append(q_shell_coords)
+            q_shell_radii_np_list.append(q_shell_radii)
+
+        if save_log:
+            with open(log_path, "a") as log_file:
+                log_file.write(f"Q-scores prep time elapsed: {datetime.now() - timer_start}\n"
+                               f"-------\n")
+
+        q_scores_np = q_scores_for_clusters(q_shell_coords_torch_list,
+                                            q_shell_radii_np_list,
+                                            self.vol,
+                                            self.e_sqd_clusters_ordered,
+                                            self.e_sqd_log,
+                                            device=self._device.currentText(),
+                                            save_log=save_log,
+                                            log_path=log_path)
+
+        self.candidates_id.setText(", ".join(map(str, calculate_candidate_indices(q_scores_np) + 1)))
+
+        if save_log:
+            with open(log_path, "a") as log_file:
+                log_file.write(f"-------\n"
+                               f"DiffFit Q-scores calculation time elapsed: {datetime.now() - timer_start}\n"
+                               f"-------\n\n")
+
+        q_score_column = 3
+        self.e_sqd_clusters_ordered = np.insert(self.e_sqd_clusters_ordered, q_score_column, q_scores_np, axis=1)
+        self.e_sqd_clusters_ordered = self.e_sqd_clusters_ordered[self.e_sqd_clusters_ordered[:, q_score_column].argsort()[::-1]]
+
+        # ======= Create fit results table
+        self.model = TableModel(self.e_sqd_clusters_ordered, self.e_sqd_log, mol_paths, mol_num_atoms)
 
         self.proxyModel = QSortFilterProxyModel()
         self.proxyModel.setSourceModel(self.model)
 
         self.view.setModel(self.proxyModel)
         self.view.setSortingEnabled(True)
-        self.view.sortByColumn(0, Qt.AscendingOrder)
+        self.view.sortByColumn(2, Qt.DescendingOrder)
         self.view.reset()
         self.view.show()
 
-        self.stats.setText("stats: {0} entries".format(self.model.rowCount()))
+        self.stats.setText("Stats: {0} entries".format(self.model.rowCount()))
 
         self.mol_paths = mol_paths
         self.cluster_idx = 0
 
+
     def _create_volume_conv_list(self, vol, smooth_by, smooth_loops, session, negative_space_value=-0.5):
         # From here on, there are three
strategies for utilizing gaussian smooth # 1. with increasing sDev on the same input volume @@ -1472,17 +1660,18 @@ def single_fit_button_clicked(self): f"Sim-map resolution: {self._single_fit_res.value()}\n" f"# positions: {self._single_fit_n_shifts.value()}\n" f"# rotations: {self._single_fit_n_quaternions.value()}\n" - f"Smooth by: {self._smooth_by.currentText()}\n" + f"Smooth by: \"{self._smooth_by.currentText()}\"\n" f"Smooth loops: {self._single_fit_gaussian_loops.value()}\n" - f"Kernel sizes: {self.smooth_kernel_sizes.text()}\n" - f"Gaussian mode: {self.Gaussian_mode}\n" + f"Kernel sizes: \"{self.smooth_kernel_sizes.text()}\"\n" + f"Gaussian mode: \"{self.Gaussian_mode}\"\n" + f"Fit atom mode: \"{self.fit_atom_mode}\"\n" f"-------\n") self.disable_spheres_clicked() single_fit_timer_start = datetime.now() - # Prepare mol anv vol + # Prepare mol and vol mol = self._object_menu.value self.fit_mol_list = [mol] @@ -1503,7 +1692,6 @@ def single_fit_button_clicked(self): # Apply the user's transformation and center mol from chimerax.geometry import Place - mol.atoms.transform(mol.position) mol_center = mol.atoms.coords.mean(axis=0) transform = Place(origin=-mol_center) mol.atoms.transform(transform) @@ -1533,7 +1721,7 @@ def single_fit_button_clicked(self): (_, _, self.mol_paths, - self.mol_centers, + self.mol_num_atoms, self.fit_result) = diff_fit( volume_conv_list, self.fit_vol.path, @@ -1564,11 +1752,14 @@ def single_fit_button_clicked(self): self._view_input_mode.setCurrentText("interactive") self._view_input_mode_changed() self.interactive_fit_result_ready = True - self.show_results(self.fit_result, self.mol_centers, self.mol_paths) + self.show_results(self.fit_result, self.mol_num_atoms, self.mol_paths, + save_log=_save_results, + log_path=f"{_out_dir}/log.log") self.tab_widget.setCurrentWidget(self.tab_view_group) self.select_table_item(0) + run(self.session, "view orient") timer_stop = datetime.now() print(f"\nDiffFit total time elapsed: {timer_stop - single_fit_timer_start}\n\n") @@ -1584,17 +1775,17 @@ def single_fit_button_clicked(self): def dependency_install_button_clicked(self): - if self.dependency_name.text() is "": + if self.dependency_name.text() == "": self.session.logger.error("You have to specify a package name.") return package_name = self.dependency_name.text() - if self.dependency_version.text() is not "": + if self.dependency_version.text() != "": package_name += f"=={self.dependency_version.text()}" cmd_list = ["install", package_name] - if self.dependency_index_url.text() is not "": + if self.dependency_index_url.text() != "": cmd_list.extend(["--index-url", self.dependency_index_url.text()]) cmd_list.extend([ @@ -1607,23 +1798,44 @@ def dependency_install_button_clicked(self): run_logged_pip(cmd_list, self.session.logger) + def q_shell_button_clicked(self, mode: str): + q_shell_generator = None + ext = "" + if mode == "Full": + from .q_shells import generate_q_shells + q_shell_generator = generate_q_shells + ext = "centered_q_shells.full.npz" + elif mode == "Simple": + from .q_shells import generate_q_shells_simple + q_shell_generator = generate_q_shells_simple + ext = "centered_q_shells.simple.npz" + + str_dir = self.q_shells_dir.text() + + for file_name in sorted(os.listdir(str_dir)): + mol_path = os.path.join(str_dir, file_name) + file_extension = Path(file_name).suffix.lower() + if file_extension in ['.pdb', '.cif']: + _, _ = generate_q_shells_wrapper(q_shell_generator, mol_path, ext, self.session) + + def sim_button_clicked(self): - output_dir = 
self.sim_out_dir.text() - if not os.path.exists(output_dir): - os.makedirs(output_dir) + str_dir = self.sim_dir.text() - sim_structures_dir = self.sim_dir.text() - for file_name in os.listdir(sim_structures_dir): - file_path = os.path.join(sim_structures_dir, file_name) - structure = run(self.session, f'open {file_path}')[0] - structure_basename = file_name.split('.')[0] + for file_name in sorted(os.listdir(str_dir)): + file_path = os.path.join(str_dir, file_name) + file_extension = Path(file_name).suffix.lower() + if file_extension in ['.pdb', '.cif']: + structure = run(self.session, f'open {file_path}')[0] + structure_basename = Path(file_name).stem - mrc_filename = f"{structure_basename}.mrc" - mrc_filepath = os.path.join(output_dir, mrc_filename) + mrc_filename = f"{structure_basename}.mrc" + mrc_filepath = os.path.join(str_dir, mrc_filename) - vol = run(self.session, f'molmap #{structure.id[0]} {self.sim_resolution.value()} gridSpacing 1.0') - run(self.session, f"save {mrc_filepath} #{vol.id[0]}") - run(self.session, f"close #{vol.id[0]}") + vol = run(self.session, f'molmap #{structure.id[0]} {self.sim_resolution.value()} gridSpacing 1.0') + run(self.session, f"save {mrc_filepath} #{vol.id[0]}") + run(self.session, f"close #{vol.id[0]}") + run(self.session, f"close #{structure.id[0]}") def split_button_clicked(self): @@ -1632,7 +1844,7 @@ def split_button_clicked(self): os.makedirs(output_dir) structure = self._split_model.value - structure_basename = os.path.basename(structure.filename).split('.')[0] + structure_basename = Path(structure.filename).stem chain_id_name_list = [] for chain in structure.chains: @@ -1674,15 +1886,15 @@ def run_button_clicked(self): f"Disk mode\n" f"Target Volume: {self.settings.target_vol_path}\n" f"Structures Folder: {self.settings.structures_directory}\n" - f"Sim-map Folder: {self.settings.structures_sim_map_dir}\n" f"Target Surface Threshold: {self.settings.target_surface_threshold}\n" f"-------\n" f"# positions: {self.settings.N_shifts}\n" f"# rotations: {self.settings.N_quaternions}\n" - f"Gaussian mode: {self.Gaussian_mode}\n" + f"Gaussian mode: \"{self.Gaussian_mode}\"\n" + f"Fit atom mode: \"{self.fit_atom_mode}\"\n" f"Conv. loops: {self.settings.conv_loops}\n" - f"Conv. kernel sizes: {self.settings.conv_kernel_sizes}\n" - f"Conv. weights: {self.settings.conv_weights}\n" + f"Conv. kernel sizes: \"{self.settings.conv_kernel_sizes}\"\n" + f"Conv. 
weights: \"{self.settings.conv_weights}\"\n" f"-------\n") @@ -1698,13 +1910,12 @@ def run_button_clicked(self): (target_vol_path, target_surface_threshold, mol_paths, - mol_centers, + mol_num_atoms, e_sqd_log) = diff_atom_comp( target_vol_path=self.settings.target_vol_path, target_surface_threshold=self.settings.target_surface_threshold, min_cluster_size=self.settings.min_cluster_size, structures_dir=self.settings.structures_directory, - structures_sim_map_dir=self.settings.structures_sim_map_dir, fit_atom_mode=self.fit_atom_mode, Gaussian_mode=self.Gaussian_mode, N_shifts=self.settings.N_shifts, @@ -1728,18 +1939,20 @@ def run_button_clicked(self): f"DiffFit optimization time elapsed: {timer_stop - timer_start}\n") # copy the directories - self.target_vol.setText(self.settings.target_vol_path) self.dataset_folder.setText("{0}".format(self.settings.output_directory)) #print(self.settings) # output is tensor, convert to numpy self.show_results(e_sqd_log.detach().cpu().numpy(), - mol_centers, + mol_num_atoms, mol_paths, target_vol_path, - target_surface_threshold) + target_surface_threshold, + save_log=True, + log_path=f"{_out_dir}/log.log") self.tab_widget.setCurrentWidget(self.tab_view_group) self.select_table_item(0) + run(self.session, "view orient") timer_stop = datetime.now() print(f"\nDiffFit total time elapsed: {timer_stop - disk_fit_timer_start}\n\n") @@ -1756,7 +1969,7 @@ def load_button_clicked(self): if self.fit_input_mode == "interactive": if self.interactive_fit_result_ready: self.show_results(self.fit_result, - self.mol_centers, + self.mol_num_atoms, [self.mol.filename], self.fit_vol.path, self.fit_vol.maximum_surface_level) @@ -1776,14 +1989,16 @@ def load_button_clicked(self): return print("loading data...") - fit_res = np.load("{0}\\fit_res.npz".format(datasetoutput)) + fit_res = np.load("{0}/fit_res.npz".format(datasetoutput)) target_vol_path = fit_res['target_vol_path'] target_surface_threshold = fit_res['target_surface_threshold'] mol_paths = fit_res['mol_paths'] - mol_centers = fit_res['mol_centers'] + mol_num_atoms = fit_res['mol_num_atoms'] opt_res = fit_res['opt_res'] - self.show_results(opt_res, mol_centers, mol_paths, target_vol_path, target_surface_threshold) + self.show_results(opt_res, mol_num_atoms, mol_paths, target_vol_path, target_surface_threshold, + save_log=True, + log_path=f"{datasetoutput}/log.log") self.select_table_item(0) run(self.session, f"view orient") @@ -1806,12 +2021,12 @@ def save_structure_button_clicked(self): def save_structure(self, targetpath, ext): if len(targetpath) > 0 and self.mol: - run(self.session, "save '{0}.{1}' models #{2}".format(targetpath, ext, self.mol.id[0])) + run(self.session, "save '{0}' models #{1}".format(targetpath, self.mol.id[0])) def save_working_volume(self, targetpath, ext): if len(targetpath) > 0 and self.vol: - run(self.session, "save '{0}.{1}' models #{2}".format(targetpath, ext, self.vol.id[0])) + run(self.session, "save '{0}' models #{1}".format(targetpath, self.vol.id[0])) def simulate_volume_clicked(self): res = self.simulate_volume_resolution.value() @@ -1820,9 +2035,33 @@ def simulate_volume_clicked(self): res) elif self.fit_input_mode == "interactive": from chimerax.map.molmap import molecule_map - self.mol_vol = molecule_map(self.session, self.mol.atoms, res, grid_spacing=self.vol.data_origin_and_step()[1][0] / 3) + self.mol_vol = molecule_map(self.session, self.mol.atoms, res, grid_spacing=self.vol.data.step[0] / 3.0) return + + def save_candidates(self): + """Save candidates action triggered by the 
Save button."""
+        try:
+            candidates = self.candidates_id.text()
+            candidate_ids = candidates.split(",")
+
+            for candidate_id in candidate_ids:
+                # Strip any leading/trailing spaces
+                candidate_id = candidate_id.strip()
+                self.select_table_item(int(candidate_id) - 1)
+
+                base_name, ext = os.path.splitext(self.mol.name)
+                base_file_path = f"{self.candidates_folder.text()}/{base_name}"
+                # Find the first _N suffix that is not taken yet
+                counter = 1
+                file_path = f"{base_file_path}_{counter}{ext}"
+                while os.path.exists(file_path):
+                    counter += 1
+                    file_path = f"{base_file_path}_{counter}{ext}"
+
+                run(self.session, f"save {file_path} models #{self.mol.id[0]}")
+            self.session.logger.info(f"Saved candidate ids: {candidates}")
+        except Exception as e:
+            self.session.logger.error(f"Failed to save the candidates: {e}")
 
     def zero_density_button_clicked(self):
         if self.vol is None: