Skip to content

Commit

Permalink
add filament priors with rlnTomoSubtomogramRot/Tilt/Psi set to partic…
Browse files Browse the repository at this point in the history
…les pre-rotated -90 degrees around the Y axis and account for that 90 degree rotation in rlnAngleRot/Tilt/Psi and rlnAngleTilt/PsiPrior.
  • Loading branch information
alisterburt committed May 9, 2023
1 parent b2db1fa commit 13da3f0
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ relion_tomo_exclude_tilt_images = "tomography_python_programs.exclude_tilt_image
relion_tomo_align_tilt_series = "tomography_python_programs.align_tilt_series:cli"
relion_tomo_view = "tomography_python_programs.view:cli"
relion_tomo_pick = "tomography_python_programs.pick:cli"
relion_tomo_derive_particle_poses = "tomography_python_programs.derive_particle_poses:cli"
relion_tomo_get_particle_poses = "tomography_python_programs.get_particle_poses:cli"
relion_tomo_denoise = "tomography_python_programs.denoise:cli"


Expand Down
120 changes: 120 additions & 0 deletions remove_duplicates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from enum import Enum
from functools import lru_cache
from pathlib import Path

import napari
import numpy as np
import pandas as pd
import starfile
import magicgui
from scipy.spatial import KDTree

from magicgui.experimental import guiclass
from magicgui.widgets import Button

STAR_FILE = 'run_it007_data.star'
star = starfile.read(STAR_FILE)

df = star['particles']

grouped = df.groupby('rlnTomoName')
TiltSeriesId = Enum('TiltSeriesIds', {ts_id: ts_id for ts_id in grouped.groups.keys()})
for ts_id in TiltSeriesId:
first_tilt_series = ts_id
break

viewer = napari.Viewer(ndisplay=3)
widget = magicgui.widgets.create_widget(annotation=TiltSeriesId)


@guiclass
class ParameterClass:
tilt_series_id: TiltSeriesId = first_tilt_series
max_distance: int = 1
output: Path = 'deduplicated.star'


parameters = ParameterClass()


def get_zyx(tilt_series_id: TiltSeriesId) -> np.ndarray:
df = grouped.get_group(tilt_series_id.value)
zyx = df[['rlnCoordinateZ', 'rlnCoordinateY', 'rlnCoordinateX']].to_numpy()
if 'rlnOriginZAngst' in df.columns:
shifts = df[
['rlnOriginZAngst', 'rlnOriginYAngst', 'rlnOriginXAngst']].to_numpy()
zyx -= shifts
return zyx


@parameters.events.max_distance.connect
@parameters.events.tilt_series_id.connect
def remove_duplicates():
points = get_zyx(parameters.tilt_series_id)
points = _collapse_knn(
points=points,
max_distance=parameters.max_distance,

)
if 'collapsed points' not in viewer.layers:
viewer.add_points(points, size=40, name='collapsed points')
else:
viewer.layers['collapsed points'].data = points
viewer.camera.center = np.mean(points, axis=0)


def _collapse_knn(
points: np.ndarray,
max_distance: float,
k: int = 15,
) -> np.ndarray:
tree = KDTree(data=points)
distance, idx = tree.query(points, k=k, distance_upper_bound=max_distance)

# remove distances to self
distance, idx = distance[:, 1:], idx[:, 1:]

# collapse knn up to distance
idx_removed = []
collapsed_points = []
for i, (_distance, _idx) in enumerate(zip(distance, idx)):
if i in idx_removed:
continue
valid_idx = _idx[_idx < len(points)]
if len(valid_idx) == 0:
collapsed_points.append(points[i])
continue
point_group = points[valid_idx]
collapsed_points.append(point_group.mean(axis=0))
idx_removed.extend(valid_idx)
return np.stack(collapsed_points, axis=0)


def save_star_file():
path = parameters.output
dfs = []
for ts_id in TiltSeriesId:
zyx = get_zyx(ts_id)
tree = KDTree(data=zyx)
zyx_final = _collapse_knn(zyx, max_distance=parameters.max_distance)
_, idx = tree.query(zyx_final, k=1)
df = grouped.get_group(ts_id.value)
df = df.iloc[idx]
dfs.append(df)
print(f'deduplicated {ts_id.value}')
df = pd.concat(dfs)
new_star = star.copy()
new_star['particles'] = df
starfile.write(star, path, overwrite=True)
print(f'saving particles to {path}')

pass


button = Button(text='save STAR file')
button.clicked.connect(save_star_file)
parameters.gui.append(button)

viewer.window.add_dock_widget(parameters.gui, area='left', name='collapse kNN')
remove_duplicates()
napari.run()
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

from .particles import combine_particle_annotations
from .spheres import derive_poses_on_spheres
from .filaments import derive_poses_along_filament_backbones
from .filaments import get_poses_along_filament_backbones

from ._cli import cli
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import einops
import numpy as np
import pandas as pd
import starfile
Expand All @@ -15,7 +14,7 @@

@cli.command(name=COMMAND_NAME, no_args_is_help=True)
@relion_pipeline_job
def derive_poses_along_filament_backbones(
def get_poses_along_filament_backbones(
tilt_series_star_file: pathlib.Path = typer.Option(
..., help='tilt-series STAR file containing tomogram'
),
Expand All @@ -28,15 +27,9 @@ def derive_poses_along_filament_backbones(
spacing_angstroms: float = typer.Option(
..., help="spacing between particles along filaments in angstroms."
),
twist_degrees: float = typer.Option(
..., help="twist between particles in degrees."
),
add_helical_priors: bool = typer.Option(
False, help="Whether to extract rotated particles and add helical priors."
),
filament_polarity_known: bool = typer.Option(
True, help="Whether filament polarity from annotations should be fixed "
"during helical reconstruction"
"during refinement."
)
):
global_df = starfile.read(tilt_series_star_file)
Expand All @@ -55,10 +48,16 @@ def derive_poses_along_filament_backbones(
# derive equidistant poses along length of filament
path = Path(control_points=xyz)
pose_sampler = path_samplers.HelicalPoseSampler(
spacing=spacing_angstroms / pixel_size, twist=twist_degrees
spacing=spacing_angstroms / pixel_size, twist=0
)
poses = pose_sampler.sample(path)
eulers = R.from_matrix(poses.orientations).inv().as_euler(

# rot/psi are coupled when tilt==0,
# pre-rotate particles such that they have tilt=90 relative to a reference
# filament aligned along the z-axis
rotated_basis = R.from_euler('y', angles=-90, degrees=True).as_matrix()
rotated_orientations = poses.orientations @ rotated_basis
rotated_eulers = R.from_matrix(rotated_orientations).inv().as_euler(
seq='ZYZ', degrees=True,
)

Expand All @@ -69,16 +68,6 @@ def derive_poses_along_filament_backbones(
distances = np.linalg.norm(differences, axis=1)
total_length = np.sum(distances)

# rot/psi are coupled when tilt==0,
# pre-rotate particles such that they have tilt=90 relative to a reference
# filament aligned along the z-axis
if add_helical_priors is True:
rotated_basis = R.from_euler('x', angles=90, degrees=True).as_matrix()
rotated_orientations = poses.orientations @ rotated_basis
eulers = R.from_matrix(rotated_orientations).inv().as_euler(
seq='ZYZ', degrees=True,
)

# how far along the helix is each particle? in angstroms
total_length = total_length / pixel_size
distance_along_helix = np.linspace(0, 1, num=len(poses)) * total_length
Expand All @@ -90,22 +79,25 @@ def derive_poses_along_filament_backbones(
'rlnCoordinateX': poses.positions[:, 0],
'rlnCoordinateY': poses.positions[:, 1],
'rlnCoordinateZ': poses.positions[:, 2],
'rlnAngleRot': eulers[:, 0],
'rlnAngleTilt': eulers[:, 1],
'rlnAnglePsi': eulers[:, 2],
'rlnTomoSubtomogramRot': rotated_eulers[:, 0],
'rlnTomoSubtomogramTilt': rotated_eulers[:, 1],
'rlnTomoSubtomogramPsi': rotated_eulers[:, 2],
}
dfs.append(pd.DataFrame(data))
df = pd.concat(dfs)

if add_helical_priors is True:
rot_prior, tilt_prior, psi_prior = R.from_matrix(rotated_basis).inv().as_euler(
seq='ZYZ', degrees=True
)
df['rlnAngleTiltPrior'] = [tilt_prior] * len(df)
df['rlnAnglePsiPrior'] = [psi_prior] * len(df)
# add priors on orientations
rot_prior, tilt_prior, psi_prior = R.from_matrix(rotated_basis).inv().as_euler(
seq='ZYZ', degrees=True
)
df['rlnAngleRot'] = [rot_prior] * len(df)
df['rlnAngleTilt'] = [tilt_prior] * len(df)
df['rlnAnglePsi'] = [psi_prior] * len(df)
df['rlnAngleTiltPrior'] = [tilt_prior] * len(df)
df['rlnAnglePsiPrior'] = [psi_prior] * len(df)

if filament_polarity_known is False:
df['rlnAnglePsiFlipRatio'] = [0.5] * len(df)
if filament_polarity_known is False:
df['rlnAnglePsiFlipRatio'] = [0.5] * len(df)

# write output
output_file = output_directory / 'particles.star'
Expand Down
44 changes: 44 additions & 0 deletions vis_refinement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from enum import Enum

import napari
import numpy as np
import starfile
import magicgui

STAR_FILE = 'run_it007_data.star'
star = starfile.read(STAR_FILE)

df = star['particles']

grouped = df.groupby('rlnTomoName')
TiltSeriesId = Enum('TiltSeriesIds', {ts_id: ts_id for ts_id in grouped.groups.keys()})
for ts_id in TiltSeriesId:
first_tilt_series = ts_id
break

widget = magicgui.widgets.create_widget(annotation=TiltSeriesId)

viewer = napari.Viewer()
viewer.window.add_dock_widget(widget, area='left', name='tilt-series id')


@widget.changed.connect
def load_tilt_series(tilt_series_id: TiltSeriesId):
tilt_series_id = tilt_series_id.value
zyx = get_zyx(tilt_series_id)
if 'particle positions' not in viewer.layers:
viewer.add_points(zyx, name='particle positions', size=40)
else:
viewer.layers['particle positions'].data = zyx


def get_zyx(tilt_series_id: str) -> np.ndarray:
df = grouped.get_group(tilt_series_id)
zyx = df[['rlnCoordinateZ', 'rlnCoordinateY', 'rlnCoordinateX']].to_numpy()
if 'rlnOriginZAngst' in df.columns:
shifts = df[['rlnOriginZAngst', 'rlnOriginYAngst', 'rlnOriginXAngst']].to_numpy()
zyx -= shifts
return zyx

load_tilt_series(first_tilt_series)
napari.run()

0 comments on commit 13da3f0

Please sign in to comment.