Skip to content

Commit

Permalink
fix imports in GUI components
Browse files Browse the repository at this point in the history
  • Loading branch information
alisterburt committed May 5, 2023
1 parent e63bd28 commit 8cc8bd7
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, tilt_image_file: Path):

class TiltImageListWidget(QListWidget):
"""QListWidget of QTiltImageItem"""

def __init__(self):
super().__init__()
self.images = []
Expand Down Expand Up @@ -64,5 +65,3 @@ def select_all(self):
def deselect_all(self):
for i in range(self.count()):
self.item(i).setCheckState(Qt.CheckState.Unchecked)


Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Tuple, Dict, Optional, Iterable
from typing import List, Dict, Iterable

import mrcfile
import napari
Expand Down Expand Up @@ -30,12 +30,12 @@ class TiltSeriesBrowserWidget(QWidget):
tilt_series_changed: Signal = Signal()

def __init__(
self,
viewer: napari.Viewer,
tilt_series: GuiTiltSeriesSet,
cache_size: int,
*args,
**kwargs
self,
viewer: napari.Viewer,
tilt_series: GuiTiltSeriesSet,
cache_size: int,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.viewer = viewer
Expand Down Expand Up @@ -103,7 +103,8 @@ def on_tilt_series_change(self):
self._disconnect_worker_safe(worker)

if self.selected_tilt_series.name not in self._cache:
self.load_tilt_series_async(self.selected_tilt_series.name, update_viewer=True)
self.load_tilt_series_async(self.selected_tilt_series.name,
update_viewer=True)
else:
self.load_tilt_series_from_cache(self.selected_tilt_series.name)
self._remove_workers_for_uncached_tilt_series()
Expand Down Expand Up @@ -165,11 +166,13 @@ def _create_thread_worker(self, tilt_series_id: str) -> GeneratorWorker:
worker = _read_tilt_series(tilt_series_id, list(tilt_images))
return worker

def _connect_events_for_background_loading(self, worker: GeneratorWorker) -> GeneratorWorker:
def _connect_events_for_background_loading(self,
worker: GeneratorWorker) -> GeneratorWorker:
worker.yielded.connect(self._cache_tilt_series)
return worker

def _connect_events_for_gui_updates(self, worker: GeneratorWorker) -> GeneratorWorker:
def _connect_events_for_gui_updates(self,
worker: GeneratorWorker) -> GeneratorWorker:
worker.yielded.connect(self._update_tilt_series_in_viewer)
worker.finished.connect(self.tilt_series_changed.emit)
worker.finished.connect(self.preload_next_tilt_series)
Expand Down Expand Up @@ -214,7 +217,8 @@ def next_tilt_series(self, event=None):
@property
def background_workers(self) -> Iterable[GeneratorWorker]:
background_worker_keys = self._workers.keys() - {self.selected_tilt_series.name}
return (self._workers[tilt_series_id] for tilt_series_id in background_worker_keys)
return (self._workers[tilt_series_id] for tilt_series_id in
background_worker_keys)

def _update_list_from_combobox(self):
idx = self.tilt_series_combo_box.currentIndex()
Expand All @@ -227,14 +231,16 @@ def _update_combobox_from_list(self):
self.tilt_series_combo_box.blockSignals(False)


def _create_empty_tilt_series_data(tilt_image_files: List[Path], dtype=np.float32) -> np.ndarray:
def _create_empty_tilt_series_data(tilt_image_files: List[Path],
dtype=np.float32) -> np.ndarray:
with mrcfile.open(tilt_image_files[0], header_only=True) as mrc:
tilt_series_shape = (len(tilt_image_files), mrc.header.ny, mrc.header.nx)
return np.zeros(shape=tilt_series_shape, dtype=dtype)


@thread_worker(progress={'total': 40}, start_thread=False)
def _read_tilt_series(tilt_series_id: str, tilt_image_files: List[Path]) -> LazyTiltSeriesData:
def _read_tilt_series(tilt_series_id: str,
tilt_image_files: List[Path]) -> LazyTiltSeriesData:
tilt_series = _create_empty_tilt_series_data(tilt_image_files, dtype=np.float16)
lazy_tilt_series_data = LazyTiltSeriesData(
name=tilt_series_id, data=tilt_series, last_loaded_index=None, n_images_loaded=0
Expand Down
32 changes: 18 additions & 14 deletions src/tomography_python_programs/_qt/components/tomogram_browser.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
import mrcfile
import napari
import mrcfile
import numpy as np
from psygnal import Signal
from qtpy.QtWidgets import QWidget, QVBoxLayout, QLabel, QSizePolicy
from lru import LRU

from ..._metadata_models.gui.tilt_series_set import GuiTiltSeriesSet as GuiTiltSeriesSet
from ..._metadata_models.gui.tilt_series_set import GuiTiltSeriesSet
from ..._metadata_models.gui.tilt_series import GuiTiltSeries
from .tilt_series_list import TiltSeriesListWidget

IMAGE_LAYER_NAME = 'tomogram'


class TomogramBrowserWidget(QWidget):
changing_tomogram: Signal = Signal()
tomogram_changed: Signal = Signal()
image_layer: napari.layers.Image

def __init__(
self,
viewer: napari.Viewer,
tilt_series: GuiTiltSeriesSet,
cache_size: int,
*args,
**kwargs
self,
viewer: napari.Viewer,
tilt_series: GuiTiltSeriesSet,
cache_size: int,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.viewer = viewer
Expand Down Expand Up @@ -48,6 +51,10 @@ def __init__(
self.next_tilt_series = self.viewer.bind_key(']', self.next_tilt_series)
self._on_tomogram_selection_change()

@property
def image_layer(self) -> napari.layers.Image:
return self.viewer.layers[IMAGE_LAYER_NAME]

@property
def selected_tilt_series(self) -> GuiTiltSeries:
return self._selected_tilt_series
Expand All @@ -64,11 +71,11 @@ def _on_tomogram_selection_change(self):

def _update_tomogram_in_viewer(self, tomogram: np.ndarray):
if 'tomogram' in self.viewer.layers:
self.viewer.layers['tomogram'].data = tomogram
self.viewer.layers[IMAGE_LAYER_NAME].data = tomogram
else:
self.viewer.add_image(
data=tomogram,
name='tomogram',
name=IMAGE_LAYER_NAME,
depiction='plane',
plane={'thickness': 1},
rendering='minip',
Expand All @@ -87,15 +94,13 @@ def _load_tomogram(self, tilt_series_id: str, add_to_viewer: bool) -> None:
self._cache[tilt_series_id] = tomogram
if add_to_viewer:
self._update_tomogram_in_viewer(tomogram)
self.tomogram_changed.emit()

def _load_tomogram_from_cache(self, tilt_series_id: str, add_to_viewer: bool):
if add_to_viewer is True:
self._update_tomogram_in_viewer(self._cache[tilt_series_id])
self._on_tomogram_loaded()

def _on_tomogram_loaded(self):
layer = self.viewer.layers['tomogram']
layer = self.viewer.layers[IMAGE_LAYER_NAME]
layer.depiction = 'plane'
layer.plane.position = np.array(layer.data.shape) / 2
layer.plane.normal = (1, 0, 0)
Expand All @@ -118,4 +123,3 @@ def previous_tilt_series(self, event=None):

def next_tilt_series(self, event=None):
self.tilt_series_list_widget.next()

0 comments on commit 8cc8bd7

Please sign in to comment.