diff --git a/Smartscope.egg-info/PKG-INFO b/Smartscope.egg-info/PKG-INFO index 04cf662d..5a6caa85 100644 --- a/Smartscope.egg-info/PKG-INFO +++ b/Smartscope.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: Smartscope -Version: 0.9.2rc1 +Version: 0.9.3.dev0 Summary: Smartscope module for automatic CryoEM grid screening Author: Jonathan Bouvette Author-email: jonathan.bouvette@nih.gov diff --git a/Smartscope/__init__.py b/Smartscope/__init__.py index e00b937c..aa50dd62 100755 --- a/Smartscope/__init__.py +++ b/Smartscope/__init__.py @@ -7,7 +7,7 @@ import sys version_file = Path(__file__).parents[1].resolve() / 'VERSION' -__version__ = version_file.read_text() +__version__ = version_file.read_text().strip() LOGLEVEL = os.getenv('LOGLEVEL') if os.getenv('LOGLEVEL') is not None else 'DEBUG' diff --git a/Smartscope/core/config.py b/Smartscope/core/config.py index fc2065d5..b1e21fb9 100755 --- a/Smartscope/core/config.py +++ b/Smartscope/core/config.py @@ -1,18 +1,33 @@ # from asyncio import protocols import os -from typing import List +from typing import List, Dict, Optional, Union from pathlib import Path import yaml import logging import importlib import sys -from Smartscope.lib.Datatypes.base_plugin import Finder, Classifier, Selector +from Smartscope.lib.Datatypes.base_plugin import BaseFeatureAnalyzer, Finder, Classifier, Selector from Smartscope.lib.Datatypes.base_protocol import BaseProtocol logger = logging.getLogger(__name__) + + +def register_plugin(data): + if not 'pluginClass' in data.keys(): + out_class = Finder + if 'Classifier' in data['targetClass']: + out_class = Classifier + if data['targetClass'] == ['Selector']: + out_class = Selector + else: + split = data['pluginClass'].split('.') + module = importlib.import_module('.'.join(split[:-1])) + out_class = getattr(module, split[-1]) + return out_class + def register_plugins(directories, factory): for directory in directories: for file in directory.glob('*.yaml'): @@ -20,18 +35,9 @@ def register_plugins(directories, factory): with open(file) as f: data = yaml.safe_load(f) - if not 'pluginClass' in data.keys(): - out_class = Finder - if 'Classifier' in data['targetClass']: - out_class = Classifier - if data['targetClass'] == ['Selector']: - out_class = Selector - else: - split = data['pluginClass'].split('.') - module = importlib.import_module('.'.join(split[:-1])) - out_class = getattr(module, split[-1]) + out_class = register_plugin(data) - factory[data['name']] = out_class.parse_obj(data) + factory[data['name']] = out_class.model_validate(data) def get_active_plugins_list(external_plugins_directory,external_plugins_list): with open(external_plugins_list,'r') as file: @@ -43,21 +49,81 @@ def register_protocols(directories:List[Path], factory): logger.debug(f'Registering protocol {file}') with open(file) as f: data = yaml.safe_load(f) - factory[data['name']] = BaseProtocol.parse_obj(data) + factory[data['name']] = BaseProtocol.model_validate(data) def register_external_plugins(external_plugins_list,plugins_factory,protocols_factory): for path in external_plugins_list: - sys.path.insert(0, str(path)) register_plugins([path/'smartscope_plugin'/'plugins'],plugins_factory) register_protocols([path/'smartscope_plugin'/'protocols'],protocols_factory) - sys.path.remove(str(path)) def get_protocol_commands(external_plugins_list): from Smartscope.core.protocol_commands import protocolCommandsFactory for path in external_plugins_list: - sys.path.insert(0, str(path)) - module = importlib.import_module('smartscope_plugin.protocol_commands') + module = importlib.import_module(path.name +'.smartscope_plugin.protocol_commands') protocol_commands= getattr(module,'protocolCommandsFactory') - sys.path.remove(str(path)) protocolCommandsFactory.update(protocol_commands) - return protocolCommandsFactory \ No newline at end of file + return protocolCommandsFactory + +class PluginDoesnNotExistError(Exception): + + def __init__(self, plugin, plugins:Optional[Dict]= None) -> None: + if plugins is not None: + message = f'Plugin \'{plugin}\' does not exist. Available plugins are: {list(plugins.keys())}' + super().__init__(message) + +class PluginFactory: + _plugins_directory: Path + _external_plugins_directory: Optional[Path] = None + _external_plugins_list_file: Optional[Path] = None + _plugins_list: List[Path] = [] + _plugins_data: Dict[str, Dict] = {} + _factory: Dict[str, BaseFeatureAnalyzer] = {} + + def __init__(self, plugins_directory: Union[str,Path], external_plugins_list_file: Optional[Union[str,Path]]=None, external_plugins_directory: Optional[Union[str,Path]]=None) -> None: + self._plugins_directory = Path(plugins_directory) + if all([external_plugins_list_file is not None, external_plugins_directory is not None]): + self._external_plugins_list_file = Path(external_plugins_list_file) + self._external_plugins_directory = Path(external_plugins_directory) + # self.parse_plugins_directory() + # self.read_plugins_files() + # self.register_plugins() + + def parse_plugins_directory(self) -> None: + self._plugins_list += list(self._plugins_directory.glob('*.yaml')) + if self._external_plugins_directory is not None: + pass + + def read_plugins_files(self) -> None: + for file in self._plugins_list: + logger.debug(f'Reading plugin {file}') + with open(file) as f: + data = yaml.safe_load(f) + self._plugins_data[data['name']] = data + + def register_plugin(self, name:str) -> None: + data = self._plugins_data[name] + out_class = register_plugin(data) + logger.debug(f'Registering plugin {name}') + self._factory[name] = out_class.model_validate(data) + + def register_plugins(self): + for name in self._plugins_data.keys(): + self.register_plugin(name) + + def get_plugin(self, name) -> BaseFeatureAnalyzer: + if (plugin := self._factory.get(name)) is not None: + logger.debug(f'Getting plugin {name}') + return plugin + if self._plugins_list.get(name) is not None: + logger.debug(f'Plugin {name} not registered, registering it now.') + return self._register_plugin(name) + raise PluginDoesnNotExistError(name, self._factory) + + def get_plugins(self): + return self._factory + + def reload_plugins(self): + self._plugins_list = [] + self._plugins_data = {} + self._factory = {} + self.parse_plugins_directory(PLUGINS_DIRECTORY) \ No newline at end of file diff --git a/Smartscope/core/data_manipulations.py b/Smartscope/core/data_manipulations.py new file mode 100644 index 00000000..58500c14 --- /dev/null +++ b/Smartscope/core/data_manipulations.py @@ -0,0 +1,145 @@ + +from typing import List, Optional, Dict +import logging +import random +from copy import copy +from Smartscope.lib.image.target import Target +from smartscope_connector.Datatypes.querylist import QueryList +from Smartscope.core.selector_sorter import SelectorSorter, SelectorValueParser, initialize_selector +from Smartscope.core.settings.worker import PLUGINS_FACTORY +import numpy as np + +from smartscope_connector import models + +logger = logging.getLogger(__name__) + +def create_target(target:models.target.Target, model:models.target.Target, finder:str, classifier:Optional[str]=None, start_number:int=0, **extra_fields): + target_dict = target.to_dict() + context = dict(number=start_number) + context['finders'] = [models.target_label.Finder.model_validate(target_dict | dict(method_name=finder))] + if classifier is not None: + context['classifiers'] = [models.target_label.Classifier(method_name=classifier,label=target.quality)] + data = target_dict | context | extra_fields + obj = model.model_validate(data) + return obj + +def add_targets(targets:List[Target], model:Target, finder:str, classifier:Optional[str]=None, start_number:int=0, **extra_fields): + output = [] + for ind, target in enumerate(targets): + number = ind + start_number + output.append(create_target(target, model, finder, classifier, number, **extra_fields)) + return QueryList(output) + + +def get_target_methods(targets, method_type:['selectors','finders','classifiers']='selectors'): + def get_selector_methods_names(target): + items = getattr(target, method_type) + if isinstance(items, list): + return map(lambda x: x.method_name, items) + return map(lambda x: x.method_name ,list(items.all())) + + return set().union(*map(get_selector_methods_names, targets)) + + +def randomized_choice(filtered_set: set, n: int): + choices = [] + while n >= len(filtered_set): + choices += list(filtered_set) + n -= len(filtered_set) + logger.debug(f'More choices than length of filtered set, choosing one of each {choices}. {n} left to randomly choose from.') + for i in range(n): + + choice = random.choice(list(filtered_set)) + logger.debug(f'For {i}th choice, choosing {choice} from {filtered_set}.') + choices.append(choice) + filtered_set.remove(choice) + return choices + +def choose_get_index(lst, value): + indices = [i for i, x in enumerate(lst) if x == value] + if indices == []: + return None + choice = random.choice(indices) + del lst[choice] + return choice + + +def filter_out_of_range(target): + return 0 if target.is_out_of_range() else 1 + + +def filter_targets(parent, targets): + classifiers = get_target_methods(targets, 'classifiers') + selectors = get_target_methods(targets, 'selectors') + + ##Filter out of range targets + filtered = list(map(filter_out_of_range, targets)) + logger.debug(f'Filtering {len(filtered)} targets.') + + for classifier in classifiers: + for ind, target in enumerate(targets): + if filtered[ind] == 0: + continue + t_classifiers = target.classifiers + if not isinstance(t_classifiers, list): + t_classifiers = list(t_classifiers.all()) + label = next(filter(lambda x: x.method_name == classifier, t_classifiers),None) + if label is None: + continue + if PLUGINS_FACTORY[classifier].classes[label.label].value <= 0: + filtered[ind] = 0 + continue + + filtered = np.array(filtered) + for selector in selectors: + sorter = initialize_selector(parent.grid_id, selector, targets) + filtered *= np.array(sorter.classes) + logger.debug(f'Filtered classes against classifiers {classifiers} and selectors {selectors}: {filtered}') + + return filtered.tolist() + +def apply_filter(targets, filtered): + return [target for target, filt in zip(targets, filtered) if filt > 0] + +def select_random_areas(targets, filtered, n): + filtered_set = set(filtered) + if filtered_set == {0}: + return [] + filtered_set.discard(0) + logger.debug(f'Selecting from {len(filtered_set)} subsets.') + choices = randomized_choice(filtered_set, n) + logger.debug(f'Randomized choices: {choices}') + output = [] + for choice in choices: + ind = choose_get_index(filtered, choice) + if ind is None: + break + output.append(targets[ind]) + return output + +def select_n_areas(parent, n, is_bis=False): + additional_filters = dict() + if is_bis: + additional_filters['bis_type'] = 'center' + additional_filters['status__isnull'] = True + targets = list(parent.targets.filter(**additional_filters)) + filtered= filter_targets(parent, targets) + if n <=0: + return apply_filter(targets, filtered) + return select_random_areas(targets, filtered, n) + +def set_or_update_refined_finder(instance, stage_x, stage_y, stage_z): + refined = next(filter(lambda x: x.method_name == 'Recentering',instance.finders), None) + if refined is None: + original_finder = instance.finders[0] + refined = models.target_label.Finder(method_name='Recentering', + x= original_finder.x, + y= original_finder.y, + stage_x=stage_x, + stage_y=stage_y, + stage_z=stage_z,) + instance.finders.insert(0,refined) + return instance + index = instance.finders.index(refined) + instance.finders[index].set_stage_position(x=stage_x, y=stage_y, z=stage_z) + return instance diff --git a/Smartscope/core/db_manipulations.py b/Smartscope/core/db_manipulations.py index faae8a08..b5f7b6b6 100755 --- a/Smartscope/core/db_manipulations.py +++ b/Smartscope/core/db_manipulations.py @@ -132,11 +132,11 @@ def viewer_only(user): return False -def group_holes_for_BIS(hole_models, max_radius=4, min_group_size=1, queue_all=False, iterations=500, score_weight=2): +def group_holes_for_BIS(hole_models, max_radius=4, min_group_size=1, iterations=500, score_weight=2): if len(hole_models) == 0: return hole_models logger.debug( - f'grouping params, max radius = {max_radius}, min group size = {min_group_size}, queue all = {queue_all}, max iterations = {iterations}, score_weight = {score_weight}') + f'grouping params, max radius = {max_radius}, min group size = {min_group_size}, max iterations = {iterations}, score_weight = {score_weight}') # Extract coordinated for the holes prefetch_related_objects(hole_models, 'finders') coords = [] @@ -198,10 +198,6 @@ def group_holes_for_BIS(hole_models, max_radius=4, min_group_size=1, queue_all=F center = hole_models[i] group_name = center.generate_bis_group_name() - if queue_all: - center.selected = True - center.status = 'queued' - bis = g[g != i] for item in bis: i = hole_models[item] diff --git a/Smartscope/core/frames.py b/Smartscope/core/frames.py new file mode 100644 index 00000000..7f66a852 --- /dev/null +++ b/Smartscope/core/frames.py @@ -0,0 +1,36 @@ +import yaml +import re +import logging +from Smartscope.core.models import AutoloaderGrid +from Smartscope.core.settings.worker import SMARTSCOPE_CUSTOM_CONFIG + +logger = logging.getLogger(__name__) + + +def get_frames_prefix(grid:AutoloaderGrid): + detector = grid.parent.detector_id + custom_paths = SMARTSCOPE_CUSTOM_CONFIG / 'custom_paths.yaml' + if not custom_paths.exists(): + logger.debug(f'No custom paths file found at {custom_paths}') + return '' + file = yaml.safe_load(custom_paths.read_text()) + key = f'detector_id_{detector.pk}' + paths = file.get(key, None) + if paths is None: + logger.debug(f'No key {key} file found at {custom_paths}') + return '' + return paths.get('frames_prefix', '') + + +def parse_frames_prefix(prefix:str, grid:AutoloaderGrid): + pattern = r'.*(\{\{.*\}\})' + matches = re.findall(pattern, prefix) + for match in matches: + clean_match = match.replace('{{', '').replace('}}', '') + split = clean_match.split('.') + x = grid + for s in split: + x = getattr(x, s) + logger.debug(f'Parsed {match} to {x}') + prefix = prefix.replace(match,x) + return prefix \ No newline at end of file diff --git a/Smartscope/core/grid/diagnostics.py b/Smartscope/core/grid/diagnostics.py index 0ca5db73..a38255f9 100644 --- a/Smartscope/core/grid/diagnostics.py +++ b/Smartscope/core/grid/diagnostics.py @@ -14,7 +14,7 @@ def generate_diagnostic_figure(image:np.array, coords_set, outputpath:Path): for coords, color, perc_im_radius in coords_set: radius = int(image_color.shape[0]*(perc_im_radius/100)) for coord in coords: - logger.info(f'{coord}, {type(coord)}') + # logger.info(f'{coord}, {type(coord)}') cv2.circle(image_color,coord,radius,color=color, thickness=cv2.FILLED) cv2.imwrite(str(outputpath), imutils.resize(image_color,512)) diff --git a/Smartscope/core/grid/grid_io.py b/Smartscope/core/grid/grid_io.py index 1a8683e2..15629da5 100644 --- a/Smartscope/core/grid/grid_io.py +++ b/Smartscope/core/grid/grid_io.py @@ -38,4 +38,10 @@ def create_dirs_docker(working_dir): def create_grid_directories(path: str) -> None: path = Path(path) for directory in [path, path / 'raw', path / 'pngs']: - directory.mkdir(exist_ok=True) \ No newline at end of file + directory.mkdir(exist_ok=True) + + @staticmethod + def create_grid_frames_directory(path, grid_dir): + directory = Path(path, grid_dir) + directory.mkdir(parents=True, exist_ok=True) + return grid_dir \ No newline at end of file diff --git a/Smartscope/core/grid/run_square.py b/Smartscope/core/grid/run_square.py index 6df01f6c..e60a4c39 100644 --- a/Smartscope/core/grid/run_square.py +++ b/Smartscope/core/grid/run_square.py @@ -13,7 +13,8 @@ from Smartscope.core.models import HoleModel from Smartscope.core.status import status from Smartscope.core.protocols import get_or_set_protocol -from Smartscope.core.db_manipulations import update, select_n_areas, add_targets, group_holes_for_BIS +from Smartscope.core.db_manipulations import update, add_targets, group_holes_for_BIS +from Smartscope.core.data_manipulations import select_n_areas from Smartscope.lib.image_manipulations import export_as_png from Smartscope.lib.image.montage import Montage @@ -34,7 +35,6 @@ def process_square_image(square, grid, microscope_id): export_as_png(montage.image, montage.png) targets, finder_method, classifier_method, _ = find_targets(montage, protocol.finders) holes = add_targets(grid, square, targets, HoleModel, finder_method, classifier_method) - square = update(square, status=status.PROCESSED, shape_x=montage.shape_x, @@ -61,7 +61,10 @@ def process_square_image(square, grid, microscope_id): for hole in holes: hole.save() logger.info(f'Picking holes on {square}') - select_n_areas(square, grid.params_id.holes_per_square, is_bis=is_bis) + selected = select_n_areas(square, grid.params_id.holes_per_square, is_bis=is_bis) + with transaction.atomic(): + for obj in selected: + update(obj, selected=True, status='queued') square = update(square, status=status.TARGETS_PICKED) if square.status == status.TARGETS_PICKED: square = update(square, diff --git a/Smartscope/core/interfaces/fakescope_interface.py b/Smartscope/core/interfaces/fakescope_interface.py index abe13752..d35ff730 100644 --- a/Smartscope/core/interfaces/fakescope_interface.py +++ b/Smartscope/core/interfaces/fakescope_interface.py @@ -141,7 +141,7 @@ def highmag(self, file='', frames=True, earlyReturn=False): destination_dir=self.microscope.scopePath ) return - movies = os.path.join(self.microscope.scopePath, 'movies') + movies = os.path.join(self.microscope.scopePath, 'movies', self.grid_dir) logger.info(f"High resolution movies are stored at {movies} in fake mode") frames = Fake.generate_fake_file( file, @@ -154,8 +154,8 @@ def highmag(self, file='', frames=True, earlyReturn=False): def connect(self): logger.info('Connecting to fake scope.') - def setup(self, saveframes, framesName=None): - pass + def setup(self, saveframes:bool, grid_dir:str, framesName=None): + self.grid_dir = grid_dir def disconnect(self, close_valves=True): logger.info('Disconnecting from fake scope.') diff --git a/Smartscope/core/interfaces/microscope_interface.py b/Smartscope/core/interfaces/microscope_interface.py index 12d9821a..12678b75 100755 --- a/Smartscope/core/interfaces/microscope_interface.py +++ b/Smartscope/core/interfaces/microscope_interface.py @@ -27,7 +27,7 @@ def __enter__(self): def __exit__(self, exception_type, exception_value, traceback): self.disconnect() - def reset_image_shift_values(self): + def reset_image_shift_values(self, afis:bool=False): self.state.reset_image_shift_values() @abstractmethod diff --git a/Smartscope/core/interfaces/serialem_interface.py b/Smartscope/core/interfaces/serialem_interface.py index 198ab51c..ed51e37d 100644 --- a/Smartscope/core/interfaces/serialem_interface.py +++ b/Smartscope/core/interfaces/serialem_interface.py @@ -1,4 +1,4 @@ -from pathlib import PureWindowsPath +from pathlib import PureWindowsPath, Path from typing import Callable, Tuple import serialem as sem import time @@ -103,12 +103,14 @@ def set_atlas_optics_imaging_state(self, state_name:str='Atlas'): def reset_stage(self): + logger.info(f'Resetting stage to center.') sem.TiltTo(0) sem.MoveStageTo(0,0,0) def remove_slit(self): if self.detector.energyFilter: if sem.ReportEnergyFilter()[2] == 1: + logger.info('Removing slit.') sem.SetSlitIn(0) @@ -232,12 +234,12 @@ def connect(self): sem.ClearPersistentVars() sem.AllowFileOverwrite(1) - def setup(self, saveframes, framesName=None): + def setup(self, saveframes:bool, grid_dir:str='', framesName=None): if saveframes: logger.info('Saving frames enabled') sem.SetDoseFracParams('P', 1, 1, 0) - movies_directory = PureWindowsPath(self.detector.framesDir).as_posix().replace('/', '\\') - logger.info(f'Saving frames to {movies_directory}') + movies_directory = PureWindowsPath(self.detector.framesDir, grid_dir).as_posix().replace('/', '\\') + logger.info(f'SerialEM will be saving frames to {movies_directory}') sem.SetFolderForFrames(movies_directory) if framesName is not None: sem.SetFrameBaseName(0, 1, 0, framesName) diff --git a/Smartscope/core/main_commands.py b/Smartscope/core/main_commands.py index f7f1c84b..531767f1 100755 --- a/Smartscope/core/main_commands.py +++ b/Smartscope/core/main_commands.py @@ -3,18 +3,10 @@ import sys import json from django.db import transaction -from django.conf import settings from django.core.cache import cache -from Smartscope.core.models import * from Smartscope.core.test_commands import * -from Smartscope.lib.image.montage import Montage -from Smartscope.lib.image.targets import Targets -from Smartscope.lib.image_manipulations import convert_centers_to_boxes -from Smartscope.core.db_manipulations import group_holes_for_BIS, add_targets -from .autoscreen import autoscreen -from .run_grid import run_grid -from .preprocessing_pipelines import highmag_processing + import numpy as np @@ -50,6 +42,12 @@ def run(command, *args): def add_holes(id, targets): + from Smartscope.lib.image.montage import Montage + from Smartscope.lib.image.targets import Targets + from Smartscope.core.db_manipulations import add_targets + from Smartscope.core.models import SquareModel, HoleModel + from Smartscope.lib.image_manipulations import convert_centers_to_boxes + instance = SquareModel.objects.get(pk=id) montage = Montage(name=instance.name, working_dir=instance.grid_id.directory) montage.load_or_process() @@ -64,8 +62,11 @@ def add_holes(id, targets): montage=montage, target_type='hole' ) - start_number = instance.holemodel_set.order_by('-number')\ - .values_list('number', flat=True).first() + 1 + start_number = 1 + instance_number = instance.holemodel_set.order_by('-number')\ + .values_list('number', flat=True).first() + if instance_number is not None: + start_number += instance_number holes = add_targets( grid=instance.grid_id, parent=instance, @@ -111,38 +112,76 @@ def toggle_pause(microscope_id: str): open(pause_file, 'w').close() print(json.dumps(dict(pause=True))) +def select_areas(mag_level, object_id, n_areas): + from Smartscope.core.data_manipulations import select_n_areas + from Smartscope.core.models import SquareModel, AtlasModel + from Smartscope.core.db_manipulations import update + if mag_level == 'atlas': + obj = AtlasModel.objects.get(pk=object_id) + is_bis = False + else: + obj = SquareModel.objects.get(pk=object_id) + is_bis = obj.grid_id.params_id.bis_max_distance > 0 + output = select_n_areas(obj, int(n_areas), is_bis=is_bis) + logger.info(f'Selected {output} areas.') + with transaction.atomic(): + for obj in output: + update(obj, selected=True, status='queued') + print('Done.') def regroup_bis(grid_id, square_id): + from Smartscope.core.models import AutoloaderGrid, SquareModel, HoleModel + from Smartscope.core.db_manipulations import group_holes_for_BIS + from Smartscope.core.data_manipulations import filter_targets, apply_filter + from Smartscope.core.status import status grid = AutoloaderGrid.objects.get(grid_id=grid_id) - square = SquareModel.objects.get(square_id=square_id) + if square_id == 'all': + queryparams = dict(grid_id=grid_id) + else: + queryparams = dict(grid_id=grid_id, square_id=square_id) logger.debug(f"{grid_id} {square_id}") collection_params = grid.params_id logger.debug(f"Removing all holes from queue") - HoleModel.objects.filter(square_id=square,status__isnull=True)\ - .update(selected=False,status=None,bis_group=None,bis_type=None) - HoleModel.objects.filter(square_id=square,status='queued',)\ - .update(selected=False,status=None,bis_group=None,bis_type=None) - filtered_holes = HoleModel.display.filter(square_id=square,status__isnull=True) + HoleModel.objects.filter(**queryparams,status__isnull=True)\ + .update(selected=False,status=status.NULL,bis_group=None,bis_type=None) + HoleModel.objects.filter(**queryparams,status='queued',)\ + .update(selected=False,status=status.NULL,bis_group=None,bis_type=None) + + # filtered_holes = HoleModel.display.filter(**queryparams,status__isnull=True) holes_for_grouping = [] - other_holes = [] - for h in filtered_holes: - if h.is_good() and not h.is_excluded()[0] and not h.is_out_of_range(): - holes_for_grouping.append(h) + # other_holes = [] + # for h in filtered_holes: + # if h.is_good() and not h.is_excluded()[0] and not h.is_out_of_range(): + # holes_for_grouping.append(h) + squares = SquareModel.display.filter(status=status.COMPLETED,**queryparams) + for square in squares: + logger.debug(f"Filtering square {square}, {square.pk}") + targets = square.targets.filter(status__isnull=True) + filtered = filter_targets(square, targets) + holes_for_grouping += apply_filter(targets, filtered) + - logger.info(f'Filtered holes = {len(filtered_holes)}\nHoles for grouping = {len(holes_for_grouping)}') + logger.info(f'Holes for grouping = {len(holes_for_grouping)}') holes = group_holes_for_BIS( holes_for_grouping, max_radius=collection_params.bis_max_distance, min_group_size=collection_params.min_bis_group_size, - queue_all=collection_params.holes_per_square == 0 ) with transaction.atomic(): - for hole in sorted(holes, key=lambda x: x.selected): + for hole in holes: hole.save() logger.info('Regrouping BIS done.') + return squares + +def regroup_bis_and_select(grid_id, square_id): + from Smartscope.core.models import AutoloaderGrid + squares = regroup_bis(grid_id, square_id) + grid = AutoloaderGrid.objects.get(grid_id=grid_id) + for square in squares: + select_areas('square', square.pk, grid.params_id.holes_per_square) def continue_run(next_or_continue, microscope_id): @@ -194,6 +233,7 @@ def download_testfiles(overwrite=False): print('Done.') def get_atlas_to_search_offset(detector_name,maximum=0): + from Smartscope.core.models import Detector, SquareModel if isinstance(maximum, str): maximum = int(maximum) detector = Detector.objects.filter(name__contains=detector_name).first() @@ -247,6 +287,7 @@ def get_atlas_to_search_offset(detector_name,maximum=0): def export_grid(grid_id, export_to=''): + from Smartscope.core.models import AutoloaderGrid from Smartscope.core.utils.export_import import export_grid if export_to == '': export_to = os.path.join(grid.directory, 'export.yaml') @@ -265,4 +306,36 @@ def import_grid(file:str): print('Done.') +def extend_lattice(square_id): + from Smartscope.core.models import SquareModel + from Smartscope.lib.Datatypes.grid_geometry import GridGeometry, GridGeometryLevel + from Smartscope.core.mesh_rotation import calculate_hole_geometry + from Smartscope.lib.Finders.lattice_extension import lattice_extension + from Smartscope.lib.image.montage import Montage + square = SquareModel.objects.get(pk=square_id) + grid = square.grid_id + geometry = GridGeometry.load(directory=grid.directory) + rotation, spacing = geometry.get_geometry(level=GridGeometryLevel.SQUARE) + if any([rotation is None, spacing is None]): + rotation, spacing = calculate_hole_geometry(grid) + montage = Montage(name=square.name, working_dir=grid.directory) + montage.read_image() + targets = square.holemodel_set.all() + if targets.count() == 0: + print('No targets found. The square needs at least one target to center the lattice on.') + return + coords = np.array([t.coords for t in targets]) + new_targets = lattice_extension(coords, montage.image, rotation, spacing) + print(json.dumps(new_targets.tolist())) + +def highmag_processing(grid_id: str, *args, **kwargs): + from .preprocessing_pipelines import highmag_processing + highmag_processing(grid_id, *args, **kwargs) + +def autoscreen(session_id:str): + from .autoscreen import autoscreen + autoscreen(session_id=session_id) + + + \ No newline at end of file diff --git a/Smartscope/core/mesh_rotation.py b/Smartscope/core/mesh_rotation.py index cee59182..0ac9547f 100644 --- a/Smartscope/core/mesh_rotation.py +++ b/Smartscope/core/mesh_rotation.py @@ -2,6 +2,8 @@ import logging from typing import Callable from Smartscope.core.models import AutoloaderGrid, SquareModel, HoleModel +from Smartscope.lib.Datatypes.grid_geometry import GridGeometry, GridGeometryLevel +from Smartscope.lib.mesh_operations import filter_closest, get_average_angle, get_mesh_rotation_spacing # from scipy.spatial import KDTree # from scipy.signal import correlate2d from scipy.spatial.distance import cdist #, pdist @@ -28,117 +30,30 @@ def square_mesh(grid_instance): square_spacing = grid_instance.meshSize.pitch return squares,square_spacing -# def kabsch_rotation_2d(P, Q): -# """Kabsch algorithm implementation for 2D coordinates""" -# # center the point sets -# P_centered = P - np.mean(P, axis=0) -# Q_centered = Q - np.mean(Q, axis=0) -# # calculate the covariance matrix -# C = np.dot(np.transpose(P_centered), Q_centered) -# # singular value decomposition of the covariance matrix -# U, S, V = np.linalg.svd(C) -# # calculate the optimal rotation matrix -# R = np.dot(U, V) -# # calculate the rotation angle from the trace of R -# cos_theta = np.trace(R) -# sin_theta = R[1, 0] - R[0, 1] -# theta = np.degrees(np.arctan2(sin_theta, cos_theta)) -# return theta - -# def icp_rotation(P, Q, max_iterations=500, tolerance=1e-6): -# """ICP algorithm implementation to get rotation angle""" -# # initialize the rotation matrix to identity -# R = np.eye(2) -# # create a KD tree for nearest neighbor search -# tree_Q = KDTree(Q) -# # iterate until convergence -# for i in range(max_iterations): -# # find the nearest neighbors of each point in P in Q -# distances, indices = tree_Q.query(P) -# # compute the centroid of each set of points -# centroid_P = np.mean(P, axis=0) -# centroid_Q = np.mean(Q[indices], axis=0) -# # compute the centered point sets -# P_centered = P - centroid_P -# Q_centered = Q[indices] - centroid_Q -# # compute the covariance matrix -# C = np.dot(np.transpose(Q_centered), P_centered) -# # compute the SVD of the covariance matrix -# U, _, V = np.linalg.svd(C) -# # compute the optimal rotation matrix -# R_new = np.dot(U, V) -# # update the rotation matrix -# R = np.dot(R_new, R) -# # update the point set P -# P = np.dot(P, R_new.T) -# # check for convergence -# if np.abs(np.trace(R_new) - 2) < tolerance: -# print(f'tolerace reached at iteration {i}') -# break -# # compute the rotation angle from the rotation matrix -# theta = np.degrees(np.arctan2(R[1, 0], R[0, 0])) -# return theta - -# def cc_rotation(points, *args): -# # Assume we have a grid of 2D points stored as a numpy array 'points', where each row represents a point -# # Create a template that is aligned with the grid (e.g., a 1D sine wave) -# template = np.sin(np.linspace(0, 2 * np.pi, len(points))) -# # Compute the cross-correlation of the grid with the template -# corr = correlate2d(points, template[:, np.newaxis], mode='same') -# # Find the peak of the cross-correlation -# peak = np.argmax(corr) -# # Compute the angle of the grid using the phase of the peak -# angle = np.angle(np.exp(1j * 2 * np.pi * peak / len(points))) -# return np.degrees(angle) - -# def hough_rotation(points): -# h, theta, d = hough_line(np.vstack([points[:, 1], points[:, 0]])) -# # Find the peaks in the Hough transform -# peaks = hough_line_peaks(h, theta, d) -# # Compute the angle of the dominant line -# angle = np.mean(peaks[1]) -# return np.degrees(angle) - -# def PCA_rotation(points, *args): -# centered_points = points - np.mean(points, axis=0) -# # Compute the covariance matrix of the centered points -# cov = np.cov(centered_points.T) -# # Compute the eigenvectors and eigenvalues of the covariance matrix -# eigenvalues, eigenvectors = np.linalg.eig(cov) -# # Identify the index of the principal component (i.e., the eigenvector with the largest eigenvalue) -# principal_component_index = np.argmax(eigenvalues) -# # Compute the angle of the principal component -# angle = np.arctan2(eigenvectors[principal_component_index, 1], eigenvectors[principal_component_index, 0]) -# return np.degrees(angle) - -def filter_closest(points,max_dist): - distances = cdist(points,points) - out_points = [] - for ind,row in enumerate(distances): - indexes= [i[1] for i in np.argwhere([row > 0, row < max_dist]) if i[0] ==1 and i[1] != ind] - filtered = points[indexes,:] - points[ind] - out_points.extend(filtered) - return np.array(out_points) - -def atan2_firstquad(point): - angle = np.degrees(np.arctan2(point[1],point[0])) - while angle > 90: - angle-=90 - while angle < 0: - angle += 90 - return angle - -def get_average_angle(points): - angles = np.apply_along_axis(atan2_firstquad,axis=1, arr=points) - return np.mean(angles) - def get_mesh_rotation(grid:AutoloaderGrid, level:Callable=hole_mesh, algo:Callable=get_average_angle): # grid = AutoloaderGrid.objects.get(pk=grid_id) targets, mesh_spacing = level(grid) stage_coords = np.array([t.stage_coords for t in targets]) logger.debug(f'Found {len(targets)} targets. Mesh Spacing is {mesh_spacing} um.') - filtered_points= filter_closest(stage_coords, mesh_spacing*1.08) + filtered_points, _= filter_closest(stage_coords, mesh_spacing*1.08) rotation = algo(filtered_points) logger.debug(f'Calculated mesh rotation: {rotation}') return rotation + +def calculate_hole_geometry(grid:AutoloaderGrid): + targets, mesh_spacing = hole_mesh(grid) + coords = np.array([t.coords for t in targets]) + pixel_size = targets[0].parent.pixel_size + logger.debug(f'Calculating hole geometry for grid {grid} with {len(targets)} holes and mesh spacing: {mesh_spacing} um. Pixel size of {targets[0].parent}: {pixel_size} A.') + rotation, spacing = get_mesh_rotation_spacing(coords, mesh_spacing / pixel_size * 10_000) + + geometry = GridGeometry.load(directory=grid.directory) + geometry.set_geometry(level=GridGeometryLevel.SQUARE, spacing=spacing, rotation=rotation) + geometry.save(directory=grid.directory) + logger.info(f'Updated grid {grid} with rotation: {rotation} degrees and spacing: {spacing} pixels.') + return rotation, spacing + +def save_mm_geometry(grid:AutoloaderGrid): + pass + \ No newline at end of file diff --git a/Smartscope/core/models/atlas.py b/Smartscope/core/models/atlas.py index d9f08bd7..e444629d 100644 --- a/Smartscope/core/models/atlas.py +++ b/Smartscope/core/models/atlas.py @@ -27,6 +27,10 @@ def group(self): def alias_name(self): return 'Atlas' + @property + def prefix_lower(self): + return self.prefix.lower() + @property def prefix(self): return 'Atlas' diff --git a/Smartscope/core/models/grid.py b/Smartscope/core/models/grid.py index 2dbbe44a..35ea4762 100644 --- a/Smartscope/core/models/grid.py +++ b/Smartscope/core/models/grid.py @@ -81,6 +81,11 @@ def collection_mode(self): if self.params_id.holes_per_square <= 0: return 'collection' return 'screening' + + def frames_dir(self, prefix:str=''): + if prefix: + return Path(f'{prefix}_{self.parent.working_directory}', f'{self.position}_{self.name}') + return Path(self.parent.working_directory, f'{self.position}_{self.name}') @property def atlas(self): @@ -132,10 +137,10 @@ def protocol(self): @property - def directory(self): + def directory(self) -> Path: self_wd = f'{self.position}_{self.name}' wd = self.parent.directory - return os.path.join(wd, self_wd) + return Path(wd, self_wd) class Meta(BaseModel.Meta): unique_together = ('position', 'name', 'session_id') diff --git a/Smartscope/core/models/high_mag.py b/Smartscope/core/models/high_mag.py index ea2939df..958e419e 100644 --- a/Smartscope/core/models/high_mag.py +++ b/Smartscope/core/models/high_mag.py @@ -4,6 +4,9 @@ from Smartscope.core.svg_plots import drawHighMag from Smartscope.lib.image_manipulations import embed_image +from .extra_property_mixin import ExtraPropertyMixin +from .target import Target +from .hole import HoleModel class HighMagImageManager(models.Manager): @@ -16,10 +19,9 @@ def get_queryset(self): return super().get_queryset().prefetch_related('finders').prefetch_related('classifiers').prefetch_related('selectors') -from .extra_property_mixin import ExtraPropertyMixin -from .target import Target + class HighMagModel(Target, ExtraPropertyMixin): - from .hole import HoleModel + hm_id = models.CharField(max_length=30, primary_key=True, editable=False) hole_id = models.ForeignKey( @@ -35,6 +37,9 @@ class HighMagModel(Target, ExtraPropertyMixin): astig = models.FloatField(null=True) angast = models.FloatField(null=True) ctffit = models.FloatField(null=True) + tilt_axis_angle = models.FloatField(null=True) + tilt_angle = models.FloatField(null=True) + ice_thickness = models.IntegerField(null=True) # aliases objects = HighMagImageManager() display = DisplayManager() diff --git a/Smartscope/core/models/screening_session.py b/Smartscope/core/models/screening_session.py index e791b161..e788c263 100644 --- a/Smartscope/core/models/screening_session.py +++ b/Smartscope/core/models/screening_session.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path from .base_model import * @@ -87,7 +88,7 @@ def directory(self): cache_key = f'{self.session_id}_directory' if (directory:=cache.get(cache_key)) is not None: logger.info(f'Session {self} directory from cache.') - return directory + return Path(directory) cwd = find_screening_session(root_directories(self),self.working_directory) cache.set(cache_key,cwd,timeout=10800) return cwd diff --git a/Smartscope/core/models/square.py b/Smartscope/core/models/square.py index db5457f4..017b9833 100644 --- a/Smartscope/core/models/square.py +++ b/Smartscope/core/models/square.py @@ -1,5 +1,6 @@ from .base_model import * + class ImageManager(models.Manager): use_for_related_fields = True @@ -92,16 +93,20 @@ def parent_stage_z(self): @property def targets(self): - return self.holemodel_set.all() + return self.holemodel_set(manager='display').all() + + @classmethod + def target_model(cls): + from .hole import HoleModel + return HoleModel # @cached_model_property(key_prefix='svg', # extra_suffix_from_function=['method'], timeout=3600) def svg(self, display_type, method): - from .hole import HoleModel from Smartscope.core.svg_plots import drawSquare - holes = list(HoleModel.display.filter(square_id=self.square_id)) + holes = list(self.target_model().display.filter(square_id=self.square_id)) sq = drawSquare(self, holes, display_type, method) return sq diff --git a/Smartscope/core/models/target.py b/Smartscope/core/models/target.py index ef3b3e23..3522e007 100644 --- a/Smartscope/core/models/target.py +++ b/Smartscope/core/models/target.py @@ -45,6 +45,14 @@ class Target(BaseModel): class Meta: abstract = True + @property + def prefix(self): + raise NotImplementedError('Prefix must be implemented in the subclass') + + @property + def prefix_lower(self): + return self.prefix.lower() + @property def group(self): return self.grid_id.session_id.group diff --git a/Smartscope/core/models/target_label.py b/Smartscope/core/models/target_label.py index aa569ae7..6078e67d 100644 --- a/Smartscope/core/models/target_label.py +++ b/Smartscope/core/models/target_label.py @@ -3,6 +3,7 @@ ''' from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType +from django.utils import timezone import math from .base_model import * @@ -17,10 +18,12 @@ class TargetLabel(BaseModel): object_id = models.CharField(max_length=30) content_object = GenericForeignKey('content_type', 'object_id') method_name = models.CharField(max_length=50, null=True) + created_at = models.DateTimeField(auto_now_add=True) class Meta: abstract = True app_label = 'API' + ordering = ['-created_at'] class Finder(TargetLabel): @@ -54,3 +57,4 @@ class Selector(TargetLabel): class Meta(BaseModel.Meta): db_table = 'selector' + diff --git a/Smartscope/core/pipelines/cryosparc_live.py b/Smartscope/core/pipelines/cryosparc_live.py index 1c69e448..2ecd72cf 100644 --- a/Smartscope/core/pipelines/cryosparc_live.py +++ b/Smartscope/core/pipelines/cryosparc_live.py @@ -9,7 +9,7 @@ from Smartscope.core.db_manipulations import websocket_update from Smartscope.core.models.grid import AutoloaderGrid -from Smartscope.lib.preprocessing_methods import get_CTFFIN4_data, \ +from Smartscope.lib.preprocessing_methods import get_CTFFIND5_data, \ process_hm_from_average, process_hm_from_frames, processing_worker_wrapper from Smartscope.core.models.models_actions import update_fields diff --git a/Smartscope/core/pipelines/smartscope_preprocessing_pipeline.py b/Smartscope/core/pipelines/smartscope_preprocessing_pipeline.py index 65ff4358..2438bfe1 100644 --- a/Smartscope/core/pipelines/smartscope_preprocessing_pipeline.py +++ b/Smartscope/core/pipelines/smartscope_preprocessing_pipeline.py @@ -1,15 +1,16 @@ from functools import partial import time -from typing import Dict +from typing import Dict, List import multiprocessing import logging from pathlib import Path from Smartscope.core.db_manipulations import websocket_update +from Smartscope.core.frames import get_frames_prefix, parse_frames_prefix from Smartscope.core.models.grid import AutoloaderGrid -from Smartscope.lib.preprocessing_methods import get_CTFFIN4_data, \ +from Smartscope.lib.preprocessing_methods import get_CTFFIND5_data, \ process_hm_from_average, process_hm_from_frames, processing_worker_wrapper from Smartscope.core.models.models_actions import update_fields @@ -29,7 +30,7 @@ class SmartscopePreprocessingPipeline(PreprocessingPipeline): description = 'Default CPU-based Processing pipeline using IMOD alignframe and CTFFIND4.' to_process_queue = multiprocessing.JoinableQueue() processed_queue = multiprocessing.Queue() - child_process = [] + child_process: List[multiprocessing.Process] = [] to_update = [] incomplete_processes = [] cmdkwargs_handler = SmartScopePreprocessingCmdKwargs @@ -41,9 +42,12 @@ def __init__(self, grid: AutoloaderGrid, cmd_data:Dict): self.detector = self.grid.session_id.detector_id self.cmd_data = self.cmdkwargs_handler.parse_obj(cmd_data) logger.debug(self.cmd_data) - self.frames_directory = [Path(self.detector.frames_directory)] + self.frames_directory = [Path(self.detector.frames_directory, grid.frames_dir(prefix=parse_frames_prefix(get_frames_prefix(grid),grid)))] + if self.cmd_data.frames_directory is not None: self.frames_directory.append(self.cmd_data.frames_directory) + frames_dirs_str = '\n\t-'.join([str(x) for x in self.frames_directory]) + logger.info(f"Looking for frames in the following directories: {frames_dirs_str}") def clear_queue(self): while True: @@ -53,19 +57,24 @@ def clear_queue(self): except multiprocessing.queues.Empty: break - def start(self): - session = self.grid.session_id - logger.info(f'Starting the preprocessing with {self.cmd_data.n_processes}') - for n in range(int(self.cmd_data.n_processes)): + def start_processes(self, n_processes): + for n in range(int(n_processes)): proc = multiprocessing.Process( target=processing_worker_wrapper, - args=(session.directory, self.to_process_queue,), + args=(self.grid.directory, self.to_process_queue,), kwargs={'output_queue': self.processed_queue} ) proc.start() self.child_process.append(proc) + + + def start(self): + # session = self.grid.session_id + logger.info(f'Starting the preprocessing with {self.cmd_data.n_processes}') + self.start_processes(self.cmd_data.n_processes) self.list_incomplete_processes() while not self.is_stop_file(): + self.check_children_processes() self.queue_incomplete_processes() self.to_process_queue.join() self.check_for_update() @@ -102,6 +111,19 @@ def queue_incomplete_processes(self): self.to_process_queue.put( [from_frames, [], dict(name=obj.name, frames_file_name=obj.frames)] ) + + def check_children_processes(self): + logger.info(f'Checking the status of children processes.') + for proc in self.child_process: + logger.debug(f'Checking process {proc}') + proc.join(1) + logger.debug(f'Process {proc} joined') + if proc.is_alive(): + logger.debug(f'Child process {proc} is still alive.') + continue + logger.error(f'Child process {proc} has died. Restarting it.') + self.child_process.remove(proc) + self.start_processes(1) def stop(self): for proc in self.child_process: @@ -111,21 +133,22 @@ def stop(self): logger.debug('Process joined') def check_for_update(self): + logger.info(f'Checking for updates.') while self.processed_queue.qsize() > 0: movie = self.processed_queue.get() data = dict() if not movie.check_metadata(): data['status'] = 'skipped' - filtered_instances = list(filter(lambda x: x.name == movie.name, self.incomplete_processes)) - if len(filtered_instances) != 1: + filtered_instance = next(filter(lambda x: x.name == movie.name, self.incomplete_processes),None) + if filtered_instance is None: logger.error(f'Could not find {movie.name} in {self.incomplete_processes}. Will try again on the next cycle.') continue - if instance.status != 'skipped': - self.to_update.append(update_fields(instance, data)) + if filtered_instance.status != 'skipped': + self.to_update.append(update_fields(filtered_instance, data)) continue logger.debug(f'Updating {movie.name}') try: - data = get_CTFFIN4_data(movie.ctf) + data = get_CTFFIND5_data(movie.ctf) except Exception as err: logger.exception(err) logger.info(f'An error occured while getting CTF data from {movie.name}. Will try again later.') diff --git a/Smartscope/core/preprocessing_pipelines.py b/Smartscope/core/preprocessing_pipelines.py index 78e62980..901e4751 100755 --- a/Smartscope/core/preprocessing_pipelines.py +++ b/Smartscope/core/preprocessing_pipelines.py @@ -12,7 +12,7 @@ from .pipelines import PreprocessingPipelineCmd, SmartscopePreprocessingPipeline, CryoSPARCPipeline -PREPROCESSING_PIPELINE_FACTORY = dict(smartscopePipeline=SmartscopePreprocessingPipeline, cryoSPARC=CryoSPARCPipeline) +PREPROCESSING_PIPELINE_FACTORY = dict(smartscopePipeline=SmartscopePreprocessingPipeline,) #cryoSPARC=CryoSPARCPipeline) def load_preprocessing_pipeline(file:Path): if file.exists(): @@ -29,28 +29,26 @@ def highmag_processing(grid_id: str, *args, **kwargs) -> None: try: grid = AutoloaderGrid.objects.get(grid_id=grid_id) os.chdir(grid.directory) - # logging.getLogger('Smartscope').handlers.pop() - # logger.debug(f'Log handlers:{logger.handlers}') add_log_handlers(directory=grid.session_id.directory, name='proc.out') logger.debug(f'Log handlers:{logger.handlers}') preprocess_file = Path('preprocessing.json') cmd_data = load_preprocessing_pipeline(preprocess_file) if cmd_data is None: logger.info('Trying to load preprocessing parameters from command line arguments.') - cmd_data = PreprocessingPipelineCmd.parse_obj(**kwargs) + cmd_data = PreprocessingPipelineCmd.model_validate(**kwargs) if cmd_data.is_running(): logger.info(f'Processings with PID:{cmd_data.process_pid} seem to already be running, '+ \ 'please kill the other one before continuing.') return cmd_data.process_pid=os.getpid() - preprocess_file.write_text(cmd_data.json()) + preprocess_file.write_text(cmd_data.model_dump_json()) pipeline = PREPROCESSING_PIPELINE_FACTORY[cmd_data.pipeline](grid, cmd_data.kwargs) pipeline.start() - except Exception as e: logger.exception(e) except KeyboardInterrupt as e: - logger.exception(e) + logger.info('SIGINT recieved by the highmag_processing.') finally: logger.debug('Wrapping up') - pipeline.stop() + if 'pipeline' in locals(): + pipeline.stop() diff --git a/Smartscope/core/run_grid.py b/Smartscope/core/run_grid.py index c6299eee..baf000ae 100644 --- a/Smartscope/core/run_grid.py +++ b/Smartscope/core/run_grid.py @@ -2,9 +2,11 @@ import sys import time import logging +from enum import Enum from pathlib import Path from django.utils import timezone from django.conf import settings +from django.db import transaction logger = logging.getLogger(__name__) @@ -20,18 +22,21 @@ from Smartscope.core.selectors import selector_wrapper from Smartscope.core.models import ScreeningSession, SquareModel, AutoloaderGrid from Smartscope.core.settings.worker import PROTOCOL_COMMANDS_FACTORY +from Smartscope.core.frames import get_frames_prefix, parse_frames_prefix +from Smartscope.core.mesh_rotation import calculate_hole_geometry from Smartscope.core.status import status from Smartscope.core.protocols import get_or_set_protocol from Smartscope.core.preprocessing_pipelines import load_preprocessing_pipeline -from Smartscope.core.db_manipulations import update, select_n_areas, queue_atlas, add_targets +from Smartscope.core.db_manipulations import update, queue_atlas, add_targets +from Smartscope.core.data_manipulations import select_n_areas from Smartscope.lib.image_manipulations import export_as_png + def run_grid( grid:AutoloaderGrid, session:ScreeningSession, - scope:MicroscopeInterface ): #processing_queue:multiprocessing.JoinableQueue, """Main logic for the SmartScope process @@ -43,8 +48,11 @@ def run_grid( session_id = session.pk microscope = session.microscope_id + + grid = update(grid, refresh_from_db=True, last_update=None) # Set the Websocket_update_decorator grid property update.grid = grid + if grid.status == GridStatus.COMPLETED: logger.info(f'Grid {grid.name} already complete. grid ID={grid.grid_id}') return @@ -55,7 +63,6 @@ def run_grid( return logger.info(f'Starting {grid.name}, status={grid.status}') - grid = update(grid, refresh_from_db=True, last_update=None) if grid.status is GridStatus.NULL: grid = update(grid, status=GridStatus.STARTED, start_time=timezone.now()) @@ -71,9 +78,16 @@ def run_grid( atlas = queue_atlas(grid) # scope + + # create frames directory + prefix = parse_frames_prefix(get_frames_prefix(grid),grid) + grid_dir = grid.frames_dir(prefix=prefix) + if params.save_frames: + GridIO.create_grid_frames_directory(session.detector_id.frames_directory, grid.frames_dir(prefix=prefix)) + logger.debug(f'Saving the frames in {grid_dir}') scope.loadGrid(grid.position) is_stop_file(session_id) - scope.setup(params.save_frames, framesName=f'{session.date}_{grid.name}') + scope.setup(params.save_frames,grid_dir=grid_dir,framesName=f'{session.date}_{grid.name}') scope.reset_state() # run acquisition @@ -123,7 +137,10 @@ def run_grid( # if atlas.status == status.PROCESSED: selector_wrapper(protocol.atlas.targets.selectors, atlas, n_groups=5) - select_n_areas(atlas, grid.params_id.squares_num) + selected = select_n_areas(atlas, grid.params_id.squares_num) + with transaction.atomic(): + for obj in selected: + update(obj, selected=True, status='queued') atlas = update(atlas, status=status.COMPLETED) #Release atlas items from memory. @@ -143,10 +160,12 @@ def run_grid( break else: square, hole = get_queue(grid) + priority = get_target_priority(grid, (square, hole)) + logger.debug(f'Priority: {priority}') logger.info(f'Queued => Square: {square}, Hole: {hole}') logger.info(f'Targets done: {is_done}') - if hole is not None and (square is None or grid.collection_mode == 'screening'): + if priority == TargetPriority.HOLE: is_done = False logger.info(f'Running Hole {hole}') # process medium image @@ -196,7 +215,7 @@ def run_grid( scope.refineZLP(params.zeroloss_delay) scope.collectHardwareDark(params.hardwaredark_delay) scope.flash_cold_FEG(params.coldfegflash_delay) - elif square is not None: + elif priority == TargetPriority.SQUARE: is_done = False logger.info(f'Running Square {square}') # process square @@ -211,6 +230,7 @@ def run_grid( ) square = update(square, status=status.ACQUIRED, completion_time=timezone.now()) RunSquare.process_square_image(square, grid, microscope) + # calculate_hole_geometry(grid) elif is_done: microscope_id = session.microscope_id.pk tmp_file = os.path.join(settings.TEMPDIR, f'.pause_{microscope_id}') @@ -239,7 +259,24 @@ def run_grid( logger.info('Grid finished') return 'finished' +class TargetPriority(Enum): + HOLE = 'hole' + SQUARE = 'square' + +def get_target_priority(grid, queue): + square, hole = queue + if hole is None and square is None: + return + if hole is None: + return TargetPriority.SQUARE + if square is None: + return TargetPriority.HOLE + if grid.collection_mode == 'screening' and grid.session_id.microscope_id.vendor != 'JEOL': + return TargetPriority.HOLE + return TargetPriority.SQUARE + + def get_queue(grid): square = grid.squaremodel_set.filter(selected=True).\ diff --git a/Smartscope/core/selector_sorter.py b/Smartscope/core/selector_sorter.py new file mode 100644 index 00000000..2d17184a --- /dev/null +++ b/Smartscope/core/selector_sorter.py @@ -0,0 +1,236 @@ +import numpy as np +from pathlib import Path +from typing import List, Optional, Callable +import logging +from matplotlib import cm +from matplotlib.colors import rgb2hex +from math import floor, ceil +from pydantic import BaseModel, field_validator + +from . import models +from .settings.worker import PLUGINS_FACTORY + +logger = logging.getLogger(__name__) + + +class LagacySorterError(Exception): + pass + +class NotSetError(Exception): + pass + +class SelectorSorterData(BaseModel): + selector_name: str + low_limit: float + high_limit: Optional[float] = None + + @field_validator('low_limit', 'high_limit', mode='before') + def check_low_limit(cls, value): + if isinstance(value, str) and value.isnumeric(): + return float(value) + return value + + def create_sorter(self,): + sorter = SelectorSorter(self.selector_name) + sorter.limits = [self.low_limit, self.high_limit] + return sorter + + @property + def file_name(self): + selector_name = self.selector_name.replace(' ', '_') + return f'{selector_name.lower()}_data.json' + + def save(self, grid_directory): + with open(grid_directory / self.file_name, 'w') as f: + f.write(self.model_dump_json()) + + @classmethod + def exists(cls, directory:Path, selector_name:str): + selector_name = selector_name.replace(' ', '_') + return (directory / f'{selector_name.lower()}_data.json').exists() + + @classmethod + def load(cls, directory:Path, selector_name): + logger.info(f'Loading selector data from {directory} for {selector_name}') + selector_name = selector_name.replace(' ', '_') + with open(directory / f'{selector_name.lower()}_data.json', 'r') as f: + data = f.read() + return cls.model_validate_json(data) + + def delete(self, grid_directory): + (grid_directory / self.file_name).unlink() + + @classmethod + def parse_sorter(cls, sorter): + return cls(selector_name=sorter.selector_name, low_limit=sorter.limits[0], high_limit=sorter.limits[1]) + +def save_to_grid_directory(grid_id): + grid = models.AutoloaderGrid.objects.get(grid_id=grid_id) + return grid.directory + +def save_to_session_directory(grid_id): + grid = models.AutoloaderGrid.objects.get(grid_id=grid_id) + return grid.session_id.directory + +def save_selector_data(grid_id, selector_name:str, data:dict,save_to:Callable=save_to_grid_directory) -> SelectorSorterData: + selector_data = SelectorSorterData(selector_name=selector_name,**data) + save_directory = save_to(grid_id) + selector_data.save(save_directory) + return selector_data + +class SelectorValueParser: + + def __init__(self, selector_name:str, from_server=False): + self._selector_name = selector_name + self._from_server = from_server + + def get_selector_value(self,target): + if self._from_server: + return self.get_selector_value_from_server(target) + return self.get_selector_value_from_worker(target) + + def get_selector_value_from_worker(self,target): + return next(filter(lambda x: x.method_name == self._selector_name ,target.selectors)).value + + def get_selector_value_from_server(self,target): + return next(filter(lambda x: x.method_name == self._selector_name ,target.selectors.all())).value + + def extract_values(self, targets:List[models.Target]) -> List[float]: + values = list(map(self.get_selector_value,targets)) + if all([value == None for value in values]): + raise LagacySorterError('No values found in targets. Reverting to lagacy sorting.') + return values + +class SelectorSorter: + _limits = None + _classes:List = None + _labels:List = None + _colors:List = None + _values:List = None + + def __init__(self,selector_name:str, n_classes=5, fractional_limits:List[float]=None): + self.selector_name= selector_name + self._n_classes = n_classes + self._fractional_limits = fractional_limits + + # def __getitem__(self, index): + # return self._targets[index], *self.labels[index] + + @property + def classes(self): + if self._classes is None: + self.calculate_classes() + return self._classes + + @property + def labels(self): + if self._labels is None: + self.set_labels() + return self._labels + + @property + def limits(self): + if self._limits is None: + self.set_limits() + return self._limits + + @property + def colors(self): + if self._colors is None: + self.set_colors() + return self._colors + + @property + def values(self): + if self._values is None: + raise NotSetError('Values have not been set.') + return self._values + + @values.setter + def values(self, values:List[float]): + self._values = values + + @limits.setter + def limits(self, value:List[float]): + self._limits = value + self._classes = None + + @property + def values_range(self) -> List[float]: + return [floor(min(self.values)), ceil(max(self.values))] + + def set_limits(self): + range_ = max(self.values) - min(self.values) + self._limits = np.array(self._fractional_limits) * range_ + min(self.values) + + def set_labels(self): + logger.debug(f'Getting colored classes from selector {self.selector_name}. Inputs {len(self.values)} targets and {self._n_classes} classes with {self.limits} limits.') + # classes, limits = self.classes(self._targets, n_classes=n_classes, limits=limits) + colors = self.set_colors() + logger.debug(f'Colors are {colors}') + colored_classes = list(map(lambda x: (colors[x], x, 'Cluster' ) if x != 0 else ((colors[x], 0, 'Rejected')), self.classes)) + logger.debug(f'Colored classes are {colored_classes}') + self._labels = colored_classes + return colored_classes + + def calculate_classes(self): + # logger.debug(f'Getting classes from selector {self._selector.name}. Inputs {len(self._targets)} targets and {self._n_classes} with limits {self.limits}.') + map_in_bounds = self.included_in_limits() + step = np.diff(self.limits) / (self._n_classes) + + # for value, in_bounds in zip(values, map_in_bounds): + def get_class(value, in_bounds) -> int: + if not in_bounds: + return 0 + if value == self.limits[1]: + return self._n_classes + return int(np.floor((value - self.limits[0]) / step) + 1) + + self._classes = list(map(get_class, self.values, map_in_bounds)) + logger.debug(f'Classes are {self._classes}') + return self._classes + + # def draw(self, n_classes=5, limits=None): + # if hasattr(self._selector, 'drawMethod'): + # return self._selector.draw_method(self._targets, n_classes, limits) + + + + def included_in_limits(self): + if self.limits is None: + self.set_limits() + def selector_value_within_limits(target_value): + # logger.debug(f'Checking if {target_value} is within {self.limits}') + return self.limits[0] <= target_value <= self.limits[1] + + selector_value_within_limits = map(selector_value_within_limits,self.values) + return list(selector_value_within_limits) + + def set_colors(self): + colors = list() + cmap = cm.plasma + cmap_step = int(floor(cmap.N / self._n_classes)) + for c in range(cmap.N, 0, -cmap_step): + colors.append(rgb2hex(cmap(c))) + continue + + self._colors = colors + return colors + + +def check_directories_for_selector_data(grid:models.AutoloaderGrid, selector_name:str) -> Path: + priority = [grid.directory, grid.session_id.directory] + for directory in priority: + if SelectorSorterData.exists(directory, selector_name): + return directory + + +def initialize_selector(grid: models.AutoloaderGrid, selector:str, queryset) -> SelectorSorter: + selector_sorter = SelectorSorter(selector_name=selector,fractional_limits=PLUGINS_FACTORY[selector].limits) + directory = check_directories_for_selector_data(grid,selector) + if directory is not None: + selector_data = SelectorSorterData.load(directory, selector) + selector_sorter = selector_data.create_sorter() + selector_data = SelectorValueParser(selector, from_server=True) + selector_sorter.values = selector_data.extract_values(queryset) + return selector_sorter \ No newline at end of file diff --git a/Smartscope/core/selectors.py b/Smartscope/core/selectors.py index 0fdafb47..47aea69d 100755 --- a/Smartscope/core/selectors.py +++ b/Smartscope/core/selectors.py @@ -1,6 +1,7 @@ import numpy as np from django.db import transaction import cv2 +from typing import Optional from django.contrib.contenttypes.models import ContentType from django.db.models.query import prefetch_related_objects @@ -14,6 +15,12 @@ logger = logging.getLogger(__name__) +def generate_selector(parent, target,value:float, label:Optional[str]=None): + return dict(content_type=ContentType.objects.get_for_model(target), + object_id=target.pk, + value=value, + label=label) + def generate_equal_clusters(parent, targets, n_groups, extra_fields=dict()): output = list() if len(targets) > 0: @@ -31,9 +38,8 @@ def generate_equal_clusters(parent, targets, n_groups, extra_fields=dict()): def cluster_by_field(parent, n_groups, field='area', **kwargs): - targets = np.array(parent.targets.order_by(field)) - return generate_equal_clusters(parent, targets, n_groups) + return list(map(lambda x: generate_selector(parent,x, value=getattr(x,field)), targets)) def gray_level_selector(parent, n_groups, save=True, montage=None): @@ -43,28 +49,34 @@ def gray_level_selector(parent, n_groups, save=True, montage=None): if montage is None: montage = Montage(**parent.__dict__, working_dir=parent.grid_id.directory) montage.create_dirs() - if save: - img = cv2.bilateralFilter(auto_contrast(montage.image.copy()), 30, 75, 75) + # if save: + img = auto_contrast(montage.image) for target in targets: finder = list(target.finders.all())[0] x, y = finder.x, finder.y - target.median = np.mean(img[y - target.radius:y + target.radius, x - target.radius:x + target.radius]) - if save: - cv2.circle(img, (x, y), target.radius, target.median, 10) + extracted = img[y - target.radius:y + target.radius, x - target.radius:x + target.radius] + target.median = np.mean(extracted) + target.std = np.std(extracted) + # if save: + # cv2.circle(img, (x, y), target.radius, target.median, 10) - if save: - save_image(img, 'gray_level_selector', extension='png', destination=parent.directory, resize_to=1024) + # if save: + # save_image(img, 'gray_level_selector', extension='png', destination=parent.directory, resize_to=1024) targets.sort(key=lambda x: x.median) return generate_equal_clusters(parent, targets, n_groups, extra_fields=dict(value='median')) +def run_selector(selector_name,selection,*args, **kwargs): + method = PLUGINS_FACTORY[selector_name] + outputs = method.run(selection, *args, **kwargs) + with transaction.atomic(): + for obj in outputs: + Selector(**obj, method_name=method.name).save() + def selector_wrapper(selectors, selection, *args, **kwargs): logger.info(f'Running selectors {selectors} on {selection}') for method in selectors: - method = PLUGINS_FACTORY[method] + run_selector(method, selection, *args, **kwargs) + - outputs = method.run(selection, *args, **kwargs) - with transaction.atomic(): - for obj in outputs: - Selector(**obj, method_name=method.name).save() diff --git a/Smartscope/core/settings/server_docker.py b/Smartscope/core/settings/server_docker.py index fdce367f..c60d11f0 100755 --- a/Smartscope/core/settings/server_docker.py +++ b/Smartscope/core/settings/server_docker.py @@ -63,6 +63,7 @@ 'corsheaders', 'Smartscope.core.settings.apps.Frontend', 'Smartscope.core.settings.apps.API', + 'Smartscope.server.selector_viewer', ] MIDDLEWARE = [ diff --git a/Smartscope/core/settings/worker.py b/Smartscope/core/settings/worker.py index 916518c8..d1f3d3fd 100644 --- a/Smartscope/core/settings/worker.py +++ b/Smartscope/core/settings/worker.py @@ -1,4 +1,5 @@ import os +import sys from pathlib import Path from Smartscope.core.config import register_plugins, register_protocols, \ register_external_plugins, get_active_plugins_list, get_protocol_commands @@ -10,6 +11,7 @@ SMARTSCOPE_DEFAULT_PLUGINS = SMARTSCOPE_DEFAULT_CONFIG / 'plugins' SMARTSCOPE_CUSTOM_PLUGINS = SMARTSCOPE_CUSTOM_CONFIG / 'plugins' EXTERNAL_PLUGINS_DIRECTORY = Path(os.getenv('EXTERNAL_PLUGINS_DIRECTORY')) +sys.path.append(str(EXTERNAL_PLUGINS_DIRECTORY)) EXTERNAL_PLUGINS_LIST:list = get_active_plugins_list( EXTERNAL_PLUGINS_DIRECTORY, SMARTSCOPE_CUSTOM_CONFIG / 'external_plugins.txt' diff --git a/Smartscope/core/svg_plots.py b/Smartscope/core/svg_plots.py index 4a88e54c..f76b64b6 100644 --- a/Smartscope/core/svg_plots.py +++ b/Smartscope/core/svg_plots.py @@ -79,7 +79,7 @@ def css_color(obj, display_type, method): return PLUGINS_FACTORY[method].get_label(labels[0]) def drawAtlas(atlas, targets, display_type, method) -> draw.Drawing: - d = draw.Drawing(atlas.shape_y, atlas.shape_x, id='square-svg', displayInline=False, style_='height: 100%; width: 100%') + d = draw.Drawing(atlas.shape_y, atlas.shape_x, id='atlas-svg', displayInline=False, style_='height: 100%; width: 100%') d.append(draw.Image(0, 0, d.width, d.height, path=atlas.png, embed= not atlas.is_aws)) shapes = draw.Group(id='atlasShapes') @@ -121,6 +121,46 @@ def drawAtlas(atlas, targets, display_type, method) -> draw.Drawing: d.append(add_legend(set(labels_list), d.width, d.height, atlas.pixel_size)) return d +def drawAtlasNew(atlas, selector_sorter) -> draw.Drawing: + d = draw.Drawing(atlas.shape_y, atlas.shape_x, id=f'{atlas.prefix_lower}-svg', displayInline=False, style_='height: 100%; width: 100%') + d.append(draw.Image(0, 0, d.width, d.height, path=atlas.png, embed= not atlas.is_aws)) + + shapes = draw.Group(id=f'{atlas.prefix_lower}Shapes') + text = draw.Group(id=f'{atlas.prefix_lower}Text') + + labels_list = [] + for i, (color, label, prefix) in zip(atlas.targets, selector_sorter.labels): + if color is not None: + sz = floor(sqrt(i.area)) + finder = list(i.finders.all())[0] + if not finder.is_position_within_stage_limits(): + color = '#505050' + label = 'Out of range' + x = finder.x - sz // 2 + y = (finder.y - sz // 2) + r = draw.Rectangle(x, y, sz, sz, id=i.pk, stroke_width=floor(d.width / 300), stroke=color, fill=color, fill_opacity=0, label=label, + class_=f'target', status=i.status, onclick=f"click{i.prefix}(this)") + + if i.selected: + ft_sz = floor(d.width / 35) + t = draw.Text(str(i.number), ft_sz, x=x + sz, y=y, id=f'{i.pk}_text', paint_order='stroke', + stroke_width=floor(ft_sz / 5), stroke=color, fill='white', class_=f'svgtext {i.status}') + text.append(t) + r.args['class'] += f" {i.status}" + # if i.status == 'completed': + # if i.has_active: + # r.args['class'] += ' has_active' + # elif i.has_queued: + # r.args['class'] += ' has_queued' + # elif i.has_completed: + # r.args['class'] += ' has_completed' + labels_list.append((color, label, prefix)) + shapes.append(r) + d.append(shapes) + d.append(text) + d.append(add_scale_bar(atlas.pixel_size, d.width, d.height)) + d.append(add_legend(set(labels_list), d.width, d.height, atlas.pixel_size)) + return d def drawSquare(square, targets, display_type, method) -> draw.Drawing: d = draw.Drawing(square.shape_y, square.shape_x, id='square-svg', displayInline=False, style_='height: 100%; width: 100%') @@ -209,4 +249,22 @@ def drawHighMag(highmag) -> draw.Drawing: d = draw.Drawing(highmag.shape_y, highmag.shape_x, id=f'{highmag.name}-svg', displayInline=False, style_='height: 100%; width: 100%') d.append(draw.Image(0, 0, d.width, d.height, path=highmag.png, embed= not highmag.is_aws)) d.append(add_scale_bar(highmag.pixel_size, d.width, d.height, id_type=highmag.name)) + return d + + +def drawSelector(obj, selector_sorter) -> draw.Drawing: + d = draw.Drawing(obj.shape_y, obj.shape_x, id='selector-svg', displayInline=False, style_='height: 100%; width: 100%') + d.append(draw.Image(0, 0, d.width, d.height, path=obj.png, embed= not obj.is_aws)) + + shapes = draw.Group(id='selectorShapes') + + for index, i in enumerate(obj.targets): + sz = floor(sqrt(i.area)) + finder = list(i.finders.all())[0] + x = finder.x - sz // 2 + y = (finder.y - sz // 2) + color = 'lightgreen' if selector_sorter.classes[index] != 0 else 'red' + r = draw.Rectangle(x, y, sz, sz, id=i.pk, stroke_width=floor(d.width / 300), stroke=color, fill=color, fill_opacity=0, class_='selectorTarget',value_=selector_sorter.values[index],) + shapes.append(r) + d.append(shapes) return d \ No newline at end of file diff --git a/Smartscope/core/test_commands.py b/Smartscope/core/test_commands.py index f51390ab..63943cf9 100755 --- a/Smartscope/core/test_commands.py +++ b/Smartscope/core/test_commands.py @@ -7,10 +7,8 @@ import time from django.conf import settings -from .interfaces.tfsserialem_interface import TFSSerialemInterface -from Smartscope.lib.preprocessing_methods import process_hm_from_frames -from .grid.finders import find_targets -from Smartscope.core.models import Microscope + + logger = logging.getLogger(__name__) @@ -40,9 +38,10 @@ def test_high_mag_frame_processing( test_dir = autoscreen_dir + group + session name = grid_id ''' + from Smartscope.lib.preprocessing_methods import process_hm_from_frames os.chdir(test_dir) - frames_file_name = '20211119_AR2_0723-1_5383.tif' - frames_dirs = [Path(os.getenv('AUTOSCREENDIR')), Path(os.getenv('TEST_FILES'), 'highmag_frames')] + frames_file_name = '20230321_AB_0317_2_3302_0.0.tif' + frames_dirs = [Path(os.getenv('AUTOSCREENDIR')), Path(os.getenv('TEST_FILES'), 'highmagframes')] movie = process_hm_from_frames(name, frames_file_name=frames_file_name, frames_directories=frames_dirs) print(f'All movie data: {movie.check_metadata()}') @@ -101,6 +100,7 @@ def refine_pixel_size_from_targets(instances, spacings) -> Tuple[float, float]: def test_finder(plugin_name: str, raw_image_path: str, output_dir: str, repeats=1): # output_dir='/mnt/data/testing/' from Smartscope.lib.image.montage import Montage from Smartscope.lib.image_manipulations import auto_contrast, save_image + from .grid.finders import find_targets import cv2 import math @@ -187,4 +187,11 @@ def restore_db(backup_file=None,backup_directory='/mnt/backups/'): logger.info(f'Restoring database to {backup_file}') command = f"mysql --user={os.getenv('MYSQL_USER')} --password={os.getenv('MYSQL_PASSWORD')} --host={os.getenv('MYSQL_HOST')} --port={os.getenv('MYSQL_PORT')} {os.getenv('MYSQL_DATABASE')} < {backup_file}" subprocess.call(command, shell=True) - logger.info('Finished restoring database.') \ No newline at end of file + logger.info('Finished restoring database.') + +def test_find_hole_geometry(grid_id): + from Smartscope.core.models import AutoloaderGrid + from Smartscope.core.mesh_rotation import save_hole_geometry + grid = AutoloaderGrid.objects.get(pk=grid_id) + rotation, spacing = save_hole_geometry(grid) + print(f'Updated grid {grid} with rotation: {rotation} degrees and spacing: {spacing} pixels.') \ No newline at end of file diff --git a/Smartscope/core/tests/protocols.py b/Smartscope/core/tests/protocols.py index 08bc17af..e9d9a6c8 100644 --- a/Smartscope/core/tests/protocols.py +++ b/Smartscope/core/tests/protocols.py @@ -1,4 +1,5 @@ import Smartscope.bin.smartscope +import pytest from ..run_grid import parse_method, runAcquisition from ..interfaces.utils import generate_mock_fake_scope_interface diff --git a/Smartscope/core/tests/test_config.py b/Smartscope/core/tests/test_config.py new file mode 100644 index 00000000..769e7946 --- /dev/null +++ b/Smartscope/core/tests/test_config.py @@ -0,0 +1,50 @@ +from ..config import PluginFactory +from pathlib import Path +from Smartscope.lib.Datatypes.base_plugin import Finder + +PLUGINS_DIRECTORY = '/opt/smartscope/config/smartscope/plugins' +EXTERNAL_PLUGINS_DIRECTORY = '/opt/smartscope/external_plugins' + +def test_parse_pluging_directory(): + factory = PluginFactory(plugins_directory=PLUGINS_DIRECTORY) + factory.parse_plugins_directory() + assert len(factory._plugins_list) == 9 + +def test_read_plugins_files(): + factory = PluginFactory(plugins_directory=PLUGINS_DIRECTORY) + factory.parse_plugins_directory() + factory.read_plugins_files() + assert len(factory._plugins_data) == 9 + +def test_register_plugin(): + factory = PluginFactory(plugins_directory=PLUGINS_DIRECTORY) + factory.parse_plugins_directory() + factory.read_plugins_files() + factory.register_plugin('Manual finder') + assert len(factory._factory) == 1 + assert 'Manual finder' in factory._factory + assert isinstance(factory._factory['Manual finder'], Finder) + +def test_register_plugins(): + factory = PluginFactory(plugins_directory=PLUGINS_DIRECTORY) + factory.parse_plugins_directory() + factory.register_plugins() + assert len(factory._factory) == 9 + +def test_get_plugins(): + factory = PluginFactory(plugins_directory=PLUGINS_DIRECTORY) + factory.parse_plugins_directory() + factory.register_plugins() + plugins = factory.get_plugins() + assert len(plugins) == 9 + assert 'Manual finder' in plugins + assert isinstance(plugins['Manual finder'], Finder) + +def test_get_plugin(): + factory = PluginFactory(plugins_directory=PLUGINS_DIRECTORY) + factory.parse_plugins_directory() + factory.read_plugins_files() + factory.register_plugins() + plugin = factory.get_plugin('Manual finder') + assert isinstance(plugin, Finder) + diff --git a/Smartscope/core/tests/test_data_manipulations.py b/Smartscope/core/tests/test_data_manipulations.py new file mode 100644 index 00000000..c043f81d --- /dev/null +++ b/Smartscope/core/tests/test_data_manipulations.py @@ -0,0 +1,36 @@ +import Smartscope.bin.smartscope + +from Smartscope.core.data_manipulations import get_target_methods, filter_targets, randomized_choice, choose_get_index +from Smartscope.core.models import SquareModel + +def test_get_target_methods(): + parent = SquareModel.objects.get(pk='grid1_square11BeZGKcjKyUKT3nQh') + methods = get_target_methods(parent, 'selectors') + assert len(methods) == 1 + +def test_filter_targets(): + parent = SquareModel.objects.get(pk='grid1_square11BeZGKcjKyUKT3nQh') + filtered,filtered_set = filter_targets(parent) + assert len(filtered_set) == 5 + +def test_randomized_choice(): + filtered_set = {1,2,3,4,5} + choices = randomized_choice(filtered_set.copy(), 3) + assert len(choices) == 3 + assert len(set(choices)) == len(choices) + + choices = randomized_choice(filtered_set.copy(), 12) + print(f'Choices are {choices}') + assert len(choices) == 12 + assert filtered_set.issubset(set(choices)) + choices = choices[10:] + assert len(choices) == 2 + assert len(set(choices)) == len(choices) + +def test_choose_get_index(): + lst = [1,2,2,3,4,4,4,5] + init_len = len(lst) + value = 4 + choice = choose_get_index(lst, value) + assert choice in [4,5,6] + assert len(lst) == init_len - 1 \ No newline at end of file diff --git a/Smartscope/core/tests/test_selector_sorter.py b/Smartscope/core/tests/test_selector_sorter.py new file mode 100644 index 00000000..5f71670c --- /dev/null +++ b/Smartscope/core/tests/test_selector_sorter.py @@ -0,0 +1,77 @@ +import pytest +from Smartscope.bin import smartscope +from pathlib import Path + +from ..models import HoleModel + +from ..selector_sorter import SelectorSorter, SelectorSorterData, SelectorValueParser + +def test_selector_sorter_init(): + sorter = SelectorSorter('selector_name', 5, [0,1]) + assert sorter.selector_name == 'selector_name' + assert sorter._n_classes == 5 + assert sorter._fractional_limits == [0,1] + +def test_selector_sorter_data_init(): + data = SelectorSorterData(selector_name='selector_name', low_limit=0, high_limit=1) + assert data.selector_name == 'selector_name' + assert data.low_limit == 0 + assert data.high_limit == 1 + +def test_selector_value_parser_init(): + parser = SelectorValueParser('selector_name', from_server=True) + assert parser._selector_name == 'selector_name' + assert parser._from_server == True + + parser = SelectorValueParser('selector_name', from_server=False) + assert parser._selector_name == 'selector_name' + assert parser._from_server == False + +def test_selector_value_parsing(): + parser = SelectorValueParser('Graylevel selector', from_server=True) + targets = HoleModel.display.filter(grid_id='1grid1OkinLNkbO4t1d8zuPKk3R5KE') + assert all([isinstance(x, float) for x in parser.extract_values(targets)]) + +def test_selector_sorter_calculate_classes(): + n_classes = 5 + sorter = SelectorSorter('Graylevel selector', n_classes, [0,1]) + with pytest.raises(Exception): + sorter.calculate_classes() + + sorter.values = [1,2,3,4,5,6,7,8,9,10] + assert sorter.limits[0] == 1 + assert sorter.limits[-1] == 10 + sorter.calculate_classes() + assert len(set(sorter._classes)) == 5 + + sorter.limits = [3,9] + assert sorter._classes is None + sorter.calculate_classes() + assert len(set(sorter._classes)) == 6 + +def test_selector_data_parse_sorter(): + sorter = SelectorSorter('Graylevel selector', 5, [0,1]) + sorter.values = [1,2,3,4,5,6,7,8,9,10] + + sorter_data = SelectorSorterData.parse_sorter(sorter) + assert sorter_data.selector_name == 'Graylevel selector' + assert sorter_data.low_limit == 1 + assert sorter_data.high_limit == 10 + +def test_selector_data_create_sorter(): + sorter_data = SelectorSorterData(selector_name='Graylevel selector', low_limit=0, high_limit=10) + sorter = sorter_data.create_sorter() + assert sorter.selector_name == 'Graylevel selector' + assert sorter._n_classes == 5 + assert sorter.limits[0] == 0 + assert sorter.limits[-1] == 10 + +def test_selector_data_save_load(): + test_dir = Path('Smartscope/core/tests') + sorter_data = SelectorSorterData(selector_name='Graylevel selector', low_limit=0, high_limit=10) + sorter_data.save(test_dir) + sorter_data = SelectorSorterData.load(test_dir, 'Graylevel selector') + assert sorter_data.selector_name == 'Graylevel selector' + assert sorter_data.low_limit == 0 + assert sorter_data.high_limit == 10 + sorter_data.delete(test_dir) \ No newline at end of file diff --git a/Smartscope/lib/Datatypes/base_plugin.py b/Smartscope/lib/Datatypes/base_plugin.py index 87a052dc..756a5271 100644 --- a/Smartscope/lib/Datatypes/base_plugin.py +++ b/Smartscope/lib/Datatypes/base_plugin.py @@ -32,6 +32,7 @@ class BaseFeatureAnalyzer(BaseModel, ABC): reference: Optional[str]= '' method: Optional[str] = '' module: Optional[str] = '' + draw_method: Optional[str] = None kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict) importPaths: Union[str,List] = Field(default_factory=list) @@ -85,6 +86,7 @@ class Selector(BaseFeatureAnalyzer): clusters: Dict[(str, Any)] exclude: List[str] = Field(default_factory=list) target_class: str = TargetClass.SELECTOR + limits: List[float] = [0.0,1.0] kwargs: Dict[str, Any] = Field(default_factory=dict) def get_label(self, label): diff --git a/Smartscope/lib/Datatypes/grid_geometry.py b/Smartscope/lib/Datatypes/grid_geometry.py new file mode 100644 index 00000000..c8e6d934 --- /dev/null +++ b/Smartscope/lib/Datatypes/grid_geometry.py @@ -0,0 +1,48 @@ +from typing import Optional, List +from enum import Enum +from pathlib import Path +import json +import logging +from pydantic import BaseModel, Field + + +logger = logging.getLogger(__name__) + +class GridGeometryLevel(Enum): + ATLAS = 'square_mesh' + SQUARE = 'hole_square' + MEDMAG = 'hole_medmag' + +class GridGeometry(BaseModel): + square_mesh_spacing: Optional[float] = None + square_mesh_rotation: Optional[float] = None + hole_square_spacing: Optional[float] = None + hole_square_rotation: Optional[float] = None + hole_medmag_spacing: Optional[float] = None + hole_medmag_rotation: Optional[float] = None + + @classmethod + def load(cls, directory:str): + file = Path(directory) / 'grid_geometry.json' + if not file.exists(): + logging.debug(f'No grid_geometry.json found in {directory}, creating a fresh one.') + return cls() + logging.debug(f'Loading geometry from {str(file)}.') + json_file = json.loads(file.read_text()) + return cls.model_validate(json_file) + + def save(self, directory:str): + file = Path(directory) / 'grid_geometry.json' + file.write_text(json.dumps(self.model_dump(),indent=4)) + + def get_geometry(self, level:GridGeometryLevel): + spacing = getattr(self, f'{level.value}_spacing') + rotation = getattr(self, f'{level.value}_rotation') + if any([spacing is None, rotation is None]): + logger.warning(f'No {level.value} geometry found.') + return rotation, spacing + + def set_geometry(self, level:GridGeometryLevel, spacing:float, rotation:float): + setattr(self, f'{level.value}_spacing', spacing) + setattr(self, f'{level.value}_rotation', rotation) + logger.info(f'Set {level.value} geometry to spacing: {spacing} and rotation: {rotation}.') \ No newline at end of file diff --git a/Smartscope/lib/Finders/AIFinder/wrapper.py b/Smartscope/lib/Finders/AIFinder/wrapper.py index 69017670..9a2ae709 100755 --- a/Smartscope/lib/Finders/AIFinder/wrapper.py +++ b/Smartscope/lib/Finders/AIFinder/wrapper.py @@ -57,7 +57,7 @@ def find_holes(montage:Montage, **kwargs): logger.debug(f'{holes[0]},{type(holes[0])}') holes = [(np.array(hole)-np.array(list(montage.center)*2))*binning + np.array(list(montage.center)*2) for hole in holes] - logger.debug(f'{holes[0]},{type(holes[0])}') + # logger.debug(f'{holes[0]},{type(holes[0])}') return holes, success, dict() diff --git a/Smartscope/lib/Finders/basic_finders.py b/Smartscope/lib/Finders/basic_finders.py index 8b73423c..8544638a 100755 --- a/Smartscope/lib/Finders/basic_finders.py +++ b/Smartscope/lib/Finders/basic_finders.py @@ -248,3 +248,9 @@ def find_square_center(img): cX = int(M["m10"] / M["m00"]) cY = int(M["m01"] / M["m00"]) return np.array([cX, cY]) + +def create_square_mask(image): + cnts, center, _ = find_square(image) + mask = np.zeros(image.shape) + cv2.drawContours(mask,[cnts],-1,1,cv2.FILLED) + return mask diff --git a/Smartscope/lib/Finders/lattice_extension.py b/Smartscope/lib/Finders/lattice_extension.py new file mode 100644 index 00000000..943527b7 --- /dev/null +++ b/Smartscope/lib/Finders/lattice_extension.py @@ -0,0 +1,19 @@ +import numpy as np +import logging +from ..mesh_operations import generate_rotated_grid, calculate_translation, remove_indices, filter_oob +from .basic_finders import create_square_mask + +logger = logging.getLogger(__name__) + +def lattice_extension(input_lattice:np.ndarray, image:np.ndarray, rotation:float, spacing:float): + points = generate_rotated_grid(spacing, rotation, image.shape) + translation, _ = calculate_translation(input_lattice,points.T) + translated = points + translation.reshape(2,-1) + translation, min_idx = calculate_translation(input_lattice,translated.T) + translated = remove_indices(translated.T, min_idx) + translated = filter_oob(translated,image.shape) + translated = translated.astype(int) + mask = create_square_mask(image=image) + filtered = translated[np.where(mask[translated[:,1],translated[:,0]] == 1)] + return filtered + # return filtered, True, dict(spacing=spacing, rotation=rotation, translation=translation) diff --git a/Smartscope/lib/external_process.py b/Smartscope/lib/external_process.py index 61342bd6..fd4ca601 100644 --- a/Smartscope/lib/external_process.py +++ b/Smartscope/lib/external_process.py @@ -27,7 +27,7 @@ def align_frames( software: alignframes in IMOD used by process_hm_from_frames() ''' - com = f'alignframes -input {frames} -output {output_file} -rotation -1 ' + \ + com = f'alignframes -mem 8 -input {frames} -output {output_file} -rotation -1 -pair -2 ' + \ f'-dfile {mdoc} -volt {voltage} -plottable {output_shifts}' if gain is not None: com += f' -gain {gain}' @@ -68,7 +68,7 @@ def CTFfind( output_file = os.path.join(output_directory, 'ctf.mrc') # interactive mode required by ctffind inputs = [input_mrc, output_file, pixel_size, voltage, spherical_abberation,\ - 0.1, 512, 30, 10, 5000, 50000, 200, 'no','no','no','no','no'] + 0.1, 512, 30, 10, 5000, 40000, 200, 'no','no','no','no','no','yes','no','no','30','7','no','no','no'] inputs = '\n'.join([str(i) for i in inputs]) # f'{input_mrc}\n{output_file}\n{pixel_size}\n{voltage}\n{spherical_abberation}\n0.1\n512\n30\n10\n5000\n50000\n200\nno\nno\nno\nno\nno', p = subprocess.run( diff --git a/Smartscope/lib/mesh_operations.py b/Smartscope/lib/mesh_operations.py new file mode 100644 index 00000000..be7c6a23 --- /dev/null +++ b/Smartscope/lib/mesh_operations.py @@ -0,0 +1,95 @@ +import numpy as np +import logging +from scipy.spatial.distance import cdist + +logger = logging.getLogger(__name__) + +def filter_closest(points,max_dist): + distances = cdist(points,points) + out_points = [] + dists = [] + for ind,row in enumerate(distances): + indexes= [i[1] for i in np.argwhere([row > 0, row < max_dist]) if i[0] ==1 and i[1] != ind] + filtered = points[indexes,:] - points[ind] + out_points.extend(filtered) + dists.extend(row.copy()[indexes]) + return np.array(out_points), np.median(dists) + +def generate_rotated_grid(spacing, angle, shape): + # Generate a grid of points with spacing + x = np.arange(-shape[1],shape[1], spacing) # Adjust the range as needed + y = np.arange(-shape[0],shape[0], spacing) + xx, yy = np.meshgrid(x, y) + + # Flatten the grid of points + points = np.vstack([xx.ravel(), yy.ravel()]) + + # Define rotation matrix + theta = np.radians(angle) + c, s = np.cos(theta), np.sin(theta) + rotation_matrix = np.array(((c, -s), (s, c))) + + # Rotate the points + rotated_points = np.dot(rotation_matrix, points) + + coordinates = rotated_points + np.array([[shape[1]//2],[shape[0]//2]]) + + return coordinates + + +def remove_indices(array, indices): + # Create a boolean mask with True values for indices to remove + mask = np.ones(len(array), dtype=bool) + mask[indices] = False + + # Use boolean indexing to remove the specified indices + result = array[mask] + + return result + + +def calculate_translation(lattice1, lattice2): + # Calculate pairwise Euclidean distances between all points in the lattices + distances = cdist(lattice1, lattice2) + min_idx = np.argmin(distances,axis=1) + # Find the indices of the minimum distance + # min_index = np.unravel_index(np.argmin(distances,axis=1), distances.shape) + # Calculate the translation based on the difference between the indices of minimum distance + translation = lattice1 - lattice2[min_idx] + + return np.median(translation,axis=0), min_idx + + +def filter_oob(coordinates, shape): + x_coords, y_coords = coordinates[:,0], coordinates[:,1] + # print(x_coords, y_coords) + # Create boolean masks for x and y coordinates within specified ranges + x_mask = np.logical_and(x_coords >= 0, x_coords < shape[1]) + y_mask = np.logical_and(y_coords >= 0, y_coords < shape[0]) + + # Use boolean indexing to filter coordinates based on the masks + filtered_coordinates = coordinates[np.logical_and(x_mask, y_mask),:] + + return filtered_coordinates + +def atan2_firstquad(point): + angle = np.degrees(np.arctan2(point[1],point[0])) + while angle > 90: + angle-=90 + while angle < 0: + angle += 90 + return angle + +def get_average_angle(points): + angles = np.apply_along_axis(atan2_firstquad,axis=1, arr=points) + return np.mean(angles) + + +def get_mesh_rotation_spacing(targets, mesh_spacing_in_pixels): + # grid = AutoloaderGrid.objects.get(pk=grid_id) + # print(f'Finding points within {mesh_spacing_in_pixels} pixels.') + filtered_points, spacing= filter_closest(targets, mesh_spacing_in_pixels*1.08) + logger.debug(f'Calculated mean spacing: {spacing} pixels') + rotation = get_average_angle(filtered_points) + logger.debug(f'Calculated mesh rotation: {rotation} degrees') + return rotation, spacing \ No newline at end of file diff --git a/Smartscope/lib/preprocessing_methods.py b/Smartscope/lib/preprocessing_methods.py index e6853122..b69a625b 100755 --- a/Smartscope/lib/preprocessing_methods.py +++ b/Smartscope/lib/preprocessing_methods.py @@ -41,14 +41,14 @@ def get_CTFFIN4_data(ctf_text: Path) -> List[float]: ''' -def get_CTFFIN4_data(ctf_text: Path) -> List[float]: +def get_CTFFIND5_data(ctf_text: Path) -> List[float]: ''' get results from ctf_*.txt determined by ctffinder args: ''' logger.info(f"Try to read CTF file {ctf_text}") ctf={} - columns=['l', 'df1', 'df2', 'angast', 'phshift', 'cc', 'ctffit'] + columns=['l', 'df1', 'df2', 'angast', 'phshift', 'cc', 'ctffit','tilt_axis_angle','tilt_angle','ice_thickness'] with open(ctf_text, 'r') as f: for line in f: if not line.startswith('#'): @@ -61,6 +61,9 @@ def get_CTFFIN4_data(ctf_text: Path) -> List[float]: 'astig': ctf['df1'] - ctf['df2'], 'angast': ctf['angast'], 'ctffit': ctf['ctffit'], + 'tilt_axis_angle': ctf['tilt_axis_angle'], + 'tilt_angle': ctf['tilt_angle'], + 'ice_thickness': int(round(ctf['ice_thickness']/10)) } def process_hm_from_frames( @@ -188,6 +191,13 @@ def process_hm_from_average( ) return montage +def clear_queue(queue): + logger.info(f'Clearing queue') + queue.task_done() + while not queue.empty(): + item = queue.get() + logger.info(f'Got item={item} from queue') + queue.task_done() def processing_worker_wrapper(logdir, queue, output_queue=None): logger.info(f"processing worker: {logdir}\t{queue}\t{output_queue}") @@ -202,8 +212,8 @@ def processing_worker_wrapper(logdir, queue, output_queue=None): item = queue.get() logger.info(f'Got item={item} from queue') if item == 'exit': - queue.task_done() logger.info('Breaking processing worker loop.') + clear_queue(queue) break if item is not None: logger.debug(f'Running {item[0]} {item[1]} {item[2]} from queue') @@ -218,5 +228,8 @@ def processing_worker_wrapper(logdir, queue, output_queue=None): except Exception as e: logger.error("Error in the processing worker") logger.exception(e) + clear_queue(queue) except KeyboardInterrupt as e: logger.info('SIGINT recieved by the processing worker') + clear_queue(queue) + diff --git a/Smartscope/lib/tests/__init__.py b/Smartscope/lib/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Smartscope/server/api/migrations/0018_classifier_created_at_finder_created_at_and_more.py b/Smartscope/server/api/migrations/0018_classifier_created_at_finder_created_at_and_more.py new file mode 100644 index 00000000..15b9b2d1 --- /dev/null +++ b/Smartscope/server/api/migrations/0018_classifier_created_at_finder_created_at_and_more.py @@ -0,0 +1,40 @@ +# Generated by Django 4.2.2 on 2024-02-06 03:31 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import django.utils.timezone + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('API', '0017_alter_customgrouppath_group_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='classifier', + name='created_at', + field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), + preserve_default=False, + ), + migrations.AddField( + model_name='finder', + name='created_at', + field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), + preserve_default=False, + ), + migrations.AddField( + model_name='selector', + name='created_at', + field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), + preserve_default=False, + ), + migrations.AlterField( + model_name='screeningsession', + name='user', + field=models.ForeignKey(default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL, to_field='username'), + ), + ] diff --git a/Smartscope/server/api/migrations/0019_highmagmodel_ice_thickness_highmagmodel_tilt_angle_and_more.py b/Smartscope/server/api/migrations/0019_highmagmodel_ice_thickness_highmagmodel_tilt_angle_and_more.py new file mode 100644 index 00000000..860d3827 --- /dev/null +++ b/Smartscope/server/api/migrations/0019_highmagmodel_ice_thickness_highmagmodel_tilt_angle_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 4.2.2 on 2024-03-04 19:24 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('API', '0018_classifier_created_at_finder_created_at_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='highmagmodel', + name='ice_thickness', + field=models.IntegerField(null=True), + ), + migrations.AddField( + model_name='highmagmodel', + name='tilt_angle', + field=models.FloatField(null=True), + ), + migrations.AddField( + model_name='highmagmodel', + name='tilt_axis_angle', + field=models.FloatField(null=True), + ), + ] diff --git a/Smartscope/server/api/serializers/__init__.py b/Smartscope/server/api/serializers/__init__.py new file mode 100644 index 00000000..4cb0e527 --- /dev/null +++ b/Smartscope/server/api/serializers/__init__.py @@ -0,0 +1,2 @@ +from .export_serializers import * +from .serializers import * \ No newline at end of file diff --git a/Smartscope/server/api/export_serializers.py b/Smartscope/server/api/serializers/export_serializers.py similarity index 55% rename from Smartscope/server/api/export_serializers.py rename to Smartscope/server/api/serializers/export_serializers.py index f75270ab..a6f97a26 100644 --- a/Smartscope/server/api/export_serializers.py +++ b/Smartscope/server/api/serializers/export_serializers.py @@ -1,6 +1,9 @@ -from rest_framework.serializers import ModelSerializer +from typing import List, Tuple, Union +from rest_framework.serializers import ModelSerializer, ListSerializer +from rest_framework import serializers as drf_serializers from Smartscope.core import models -from Smartscope.server.api.serializers import GridCollectionParamsSerializer, MicroscopeSerializer,DetectorSerializer +from .serializers import GridCollectionParamsSerializer, MicroscopeSerializer,DetectorSerializer, SquareSerializer, HoleSerializer, HighMagSerializer, AtlasSerializer +from .utils import extract_targets, create_target_label_instances from django.contrib.contenttypes.models import ContentType from django.db import transaction import logging @@ -21,7 +24,6 @@ class DetailedDetectorSerializer(ModelSerializer): class Meta: model= models.Detector fields = '__all__' - # exclude = ['detector_id'] class DetailedSessionSerializer(ModelSerializer): @@ -53,9 +55,54 @@ class Meta: exclude = ['id', 'object_id', 'content_type'] class TargetSerializer(ModelSerializer): + name = drf_serializers.CharField(required=False) + finders = FinderSerializer(many=True) - selectors = SelectorSerializer(many=True) - classifiers = ClassifierSerializer(many=True) + selectors = SelectorSerializer(many=True, required=False) + classifiers = ClassifierSerializer(many=True, required=False) + + class Config: + id_alias:str = 'NotImplemented' + target_model:models.BaseModel = 'NotImplemented' + parent_model: models.BaseModel = 'NotImplemented' + parent_id_alias:str = 'NotImplemented' + + def validate(self, attrs): + return super().validate(attrs) + + def create(self,validated_data, label_types: Union[List, '__all__']= '__all__'): + labels = [] + uid = validated_data.pop('uid') + target_labels, validated_data, _ = extract_targets(validated_data, label_types= label_types) + + if uid is None: + grid_id = models.AutoloaderGrid.objects.get(grid_id=validated_data.pop('grid_id')) + parent_id = self.Config.parent_model.objects.get(pk=validated_data.pop(self.Config.parent_id_alias)) + instance = self.Meta.model(**validated_data, grid_id=grid_id, **{self.Config.parent_id_alias:parent_id}) + logger.debug(f'Created new instance with uid: {instance.pk}') + else: + instance = self.Meta.model.objects.get(pk=uid) + logger.debug(f'Working on target with uid: {instance.pk}') + + labels += create_target_label_instances(target_labels,instance.pk,ContentType.objects.get_for_model(self.Config.target_model)) + + return instance, labels + + +class AddTargetsListSerializer(ListSerializer): + + def create(self, validated_data, label_types = '__all__',): + all_targets = [] + all_labels = [] + for target in validated_data: + instance, labels = self.child.create(target, label_types=label_types) + all_targets.append(instance) + all_labels += labels + return all_targets, all_labels + + def update(self, validated_data): + for target in validated_data: + self.child.update(target) class DetailedHighMagSerializer(TargetSerializer): @@ -71,17 +118,61 @@ class Meta: model = models.HoleModel exclude = ['hole_id','square_id','grid_id'] + +class DetailedFullHoleSerializer(TargetSerializer): + targets = DetailedHighMagSerializer(many=True, required=False) + + class Meta: + model = models.HoleModel + fields = '__all__' + list_serializer_class = AddTargetsListSerializer + + class Config: + id_alias:str = 'hole_id' + target_model:models.BaseModel = models.HoleModel + parent_model: models.BaseModel = models.SquareModel + parent_id_alias:str = 'square_id' + class ScipionPluginHoleSerializer(DetailedHoleSerializer): class Meta(DetailedHoleSerializer.Meta): exclude = [] class DetailedSquareSerializer(TargetSerializer): - targets = DetailedHoleSerializer(many=True) + targets = DetailedHoleSerializer(many=True, ) class Meta: model = models.SquareModel exclude = ['square_id','atlas_id','grid_id'] + + +class DetailedFullSquareSerializer(TargetSerializer): + targets = DetailedFullHoleSerializer(many=True, required=False) + + class Meta: + model = models.SquareModel + fields = '__all__' + list_serializer_class = AddTargetsListSerializer + + class Config: + id_alias:str = 'square_id' + target_model:models.BaseModel = models.SquareModel + parent_model: models.BaseModel = models.AtlasModel + parent_id_alias:str = 'atlas_id' + +class DetailedNoTargetSquareSerializer(TargetSerializer): + # targets = DetailedHoleSerializer(many=True, required=False) + + class Meta: + model = models.SquareModel + fields = '__all__' + list_serializer_class = AddTargetsListSerializer + + class Config: + id_alias:str = 'square_id' + target_model:models.BaseModel = models.SquareModel + parent_model: models.BaseModel = models.AtlasModel + parent_id_alias:str = 'atlas_id' class DetailedAtlasSerializer(ModelSerializer): targets = DetailedSquareSerializer(many=True) @@ -90,19 +181,17 @@ class Meta: model = models.AtlasModel exclude = ['atlas_id','grid_id'] -def extract_targets(data): - target_labels= [] - target_labels += [(item,models.Finder) for item in data.pop('finders',[])] - target_labels += [(item,models.Classifier) for item in data.pop('classifiers',[])] - target_labels += [(item,models.Selector) for item in data.pop('selectors',[])] - targets = data.pop('targets',[]) - return target_labels, data, targets +class DetailedFullAtlasSerializer(ModelSerializer): + targets = DetailedNoTargetSquareSerializer(many=True) + + class Meta: + model = models.AtlasModel + fields = '__all__' + + class Config: + id_alias:str = 'atlas_id' + target_model:models.BaseModel = models.SquareModel -def create_target_label_instances(target_labels,instance,content_type): - target_labels_models = [] - for label,label_class in target_labels: - target_labels_models.append(label_class(**label,object_id=instance,content_type=content_type)) - return target_labels_models class ExportMetaSerializer(ModelSerializer): atlas = DetailedAtlasSerializer(many=True) @@ -159,6 +248,4 @@ def create(self,validated_data): [target.save() for target in target_models] [label.save() for label in target_labels_models] - return grid_model - - + return grid_model \ No newline at end of file diff --git a/Smartscope/server/api/serializers.py b/Smartscope/server/api/serializers/serializers.py similarity index 87% rename from Smartscope/server/api/serializers.py rename to Smartscope/server/api/serializers/serializers.py index a9fe3e79..5c02638a 100755 --- a/Smartscope/server/api/serializers.py +++ b/Smartscope/server/api/serializers/serializers.py @@ -12,6 +12,10 @@ from Smartscope.core.models.atlas import AtlasModel from Smartscope.core.models.square import SquareModel from Smartscope.core.models.high_mag import HighMagModel +from Smartscope.core.models.target_label import Classifier +from Smartscope.core.selector_sorter import SelectorSorter, LagacySorterError, SelectorValueParser, initialize_selector +from Smartscope.core.settings.worker import PLUGINS_FACTORY +from Smartscope.core.svg_plots import drawAtlasNew # from Smartscope.lib.storage.smartscope_storage import SmartscopeStorage from Smartscope.lib.converters import * import logging @@ -49,6 +53,12 @@ class Meta: model = ScreeningSession fields = '__all__' +class ClassifierSerializer(RESTserializers.ModelSerializer): + + class Meta: + model = Classifier + exclude = ['id'] + class AutoloaderGridSerializer(RESTserializers.ModelSerializer): @@ -267,13 +277,27 @@ def load_meta(self): return dict() return update_to_fullmeta(targets) + def svg(self): + if self.display_type == 'selectors': + try: + # plugin = PLUGINS_FACTORY[self.method] + # selector_data = SelectorValueParser(self.method, from_server=True) + # sorter = SelectorSorter(self.method, fractional_limits=plugin.limits) + # sorter.values = selector_data.extract_values(self.instance.targets) + sorter = initialize_selector(self.instance.grid_id, self.method, self.instance.targets) + return drawAtlasNew(self.instance, sorter).as_svg() + except LagacySorterError: + logger.warning('Lagacy sorter error. Reverting to lagacy sorting.') + return self.instance.svg(display_type=self.display_type, method=self.method,).as_svg() + + def to_representation(self, instance): return { 'type': 'reload', 'display_type': self.display_type, 'method': self.method, 'element': models_to_serializers[self.instance.__class__.__name__]['element'], - 'svg': self.instance.svg(display_type=self.display_type, method=self.method,).as_svg(), + 'svg': self.svg() # 'fullmeta': self.load_meta() } diff --git a/Smartscope/server/api/serializers/utils.py b/Smartscope/server/api/serializers/utils.py new file mode 100644 index 00000000..d438183f --- /dev/null +++ b/Smartscope/server/api/serializers/utils.py @@ -0,0 +1,28 @@ +from typing import List +from Smartscope.core import models +import logging + +logger = logging.getLogger(__name__) + + +def extract_targets(data, label_types:List[str]=['__all__']): + if label_types == ['__all__']: + logger.debug(f'Label types: {label_types}. Using all available label types') + label_types = ['finders','classifiers','selectors'] + target_labels= dict() + output_labels = [] + target_labels['finders'] = [(item,models.Finder) for item in data.pop('finders',[])] + target_labels['classifiers'] = [(item,models.Classifier) for item in data.pop('classifiers',[])] + target_labels['selectors'] = [(item,models.Selector) for item in data.pop('selectors',[])] + targets = data.pop('targets',[]) + for label_type in label_types: + output_labels += target_labels[label_type] + return output_labels, data, targets + +def create_target_label_instances(target_labels,instance,content_type): + target_labels_models = [] + # logger.info(f'Creating target labels for {instance}') + # logger.debug(f'Target_labels: \n{target_labels}') + for label,label_class in target_labels: + target_labels_models.append(label_class(**label,object_id=instance,content_type=content_type)) + return target_labels_models \ No newline at end of file diff --git a/Smartscope/server/api/templates/holecard.html b/Smartscope/server/api/templates/holecard.html index fa660bae..027da766 100755 --- a/Smartscope/server/api/templates/holecard.html +++ b/Smartscope/server/api/templates/holecard.html @@ -8,7 +8,7 @@