From ef6a8955f201f669d4c8dbed50a0201dbc44afe9 Mon Sep 17 00:00:00 2001 From: James Parkhurst Date: Wed, 22 Nov 2023 15:44:38 +0000 Subject: [PATCH] Refactor sim (#32) * Added calibrate a command line program * Update calibrate * Update sample margin * Updating average particles * Merging implementation from old commits * Fix indent * Fixed average_all_particles * Fixed test * Improving inelastic * Keep mean defocus * Cache Landau distributin * Update calibrate * Starting to refactor * Split out multem stuff into a single file * Update docstirngs * Fixed typo * Removed unused import * Updating simulation * Add type hints * Update workflow * Fixed failures * Blackened * Add iterable * Fixing * Fixing * Fixing * Fixed * Fixing * Fixing diffraction and exit wave etc * Using calibrate from master * Checking optics is ok * Blackened * Fixed import * Fixed some typing issues * Fixed beam spread in energy bins * Fixed check * Fixed bug * Moved ice parameters to simulate engine * Fixed extract particles * Change assert * Improve documentation * Added assert * Refactor phase plate sim * Fixed thickness * Fixed test * Add a margin * Add margin to cylinder * Removing assertr * Fixed scan * Add an assert for voxel size * Change default ice density * Update plots * Setting the GPU id * FIxed test * Fixed typing * Fixed test * Update pybind --------- Co-authored-by: James Parkhurst --- setup.py | 2 + src/parakeet/analyse/__init__.py | 1 - src/parakeet/analyse/_average_particles.py | 55 +- src/parakeet/analyse/_extract.py | 208 +++---- src/parakeet/command_line/analyse/__init__.py | 1 + .../analyse/_average_all_particles.py | 15 +- .../analyse/_average_extracted_particles.py | 116 ++++ src/parakeet/config.py | 8 +- src/parakeet/inelastic.py | 124 ++++- src/parakeet/landau.py | 24 +- src/parakeet/sample/__init__.py | 27 +- src/parakeet/sample/distribute.py | 17 +- src/parakeet/scan.py | 2 +- src/parakeet/simulate/_cbed.py | 132 +---- src/parakeet/simulate/_ctf.py | 30 +- src/parakeet/simulate/_exit_wave.py | 139 +---- src/parakeet/simulate/_image.py | 3 - src/parakeet/simulate/_optics.py | 351 +++++------- src/parakeet/simulate/_potential.py | 45 +- src/parakeet/simulate/_simple.py | 41 +- src/parakeet/simulate/engine.py | 513 ++++++++++++++++++ src/parakeet/simulate/phase_plate.py | 14 +- src/parakeet/simulate/simulation.py | 348 +----------- src/parakeet/util/calibrate_ice_model.py | 40 +- tests/test_inelastic.py | 17 +- 25 files changed, 1223 insertions(+), 1050 deletions(-) create mode 100644 src/parakeet/command_line/analyse/_average_extracted_particles.py create mode 100644 src/parakeet/simulate/engine.py diff --git a/setup.py b/setup.py index 35527ed9..ee898db6 100644 --- a/setup.py +++ b/setup.py @@ -104,10 +104,12 @@ def main(): "parakeet.metadata.export=parakeet.command_line.metadata:export", "parakeet.analyse.reconstruct=parakeet.command_line.analyse:reconstruct", "parakeet.analyse.average_particles=parakeet.command_line.analyse:average_particles", + "parakeet.analyse.average_extracted_particles=parakeet.command_line.analyse:average_extracted_particles", "parakeet.analyse.average_all_particles=parakeet.command_line.analyse:average_all_particles", "parakeet.analyse.extract=parakeet.command_line.analyse:extract", "parakeet.analyse.refine=parakeet.command_line.analyse:refine", "parakeet.analyse.correct=parakeet.command_line.analyse:correct", + "dev.parakeet.calibrate_ice_model=parakeet.util.calibrate_ice_model:main", ] }, extras_require={ diff --git a/src/parakeet/analyse/__init__.py b/src/parakeet/analyse/__init__.py index 72c6941b..7079c317 100644 --- a/src/parakeet/analyse/__init__.py +++ b/src/parakeet/analyse/__init__.py @@ -2,7 +2,6 @@ from parakeet.analyse._correct import * # noqa from parakeet.analyse._reconstruct import * # noqa from parakeet.analyse._average_particles import * # noqa -from parakeet.analyse._average_particles import * # noqa from parakeet.analyse._extract import * # noqa from parakeet.analyse._refine import * # noqa # fmt: on diff --git a/src/parakeet/analyse/_average_particles.py b/src/parakeet/analyse/_average_particles.py index e7cbdce1..b0a2758c 100644 --- a/src/parakeet/analyse/_average_particles.py +++ b/src/parakeet/analyse/_average_particles.py @@ -119,7 +119,6 @@ def _iterate_particles( print("Getting sub tomogram") sub_tomo = tomogram[x0[1] : x1[1], x0[2] : x1[2], x0[0] : x1[0]] if sub_tomo.shape == half_shape[-3:]: - print("YIELD") yield (sub_tomo, offset, orientation, j) @@ -262,6 +261,8 @@ def _average_particles_Config( assert len(positions) == len(orientations) if num_particles <= 0: num_particles = len(positions) + else: + num_particles = min(num_particles, len(positions)) print( "Averaging %d %s particles with box size %d" % (num_particles, name, length) ) @@ -310,10 +311,6 @@ def _average_particles_Config( if num[1] > 0: half[1, :, :, :] = half[1, :, :, :] / num[1] - # from matplotlib import pylab - # pylab.imshow(average[half_length, :, :]) - # pylab.show() - # Save the averaged data print("Saving half 1 to %s" % half_1_filename) handle = mrcfile.new(half_1_filename, overwrite=True) @@ -332,6 +329,7 @@ def average_all_particles( rec_file: str, average_file: str, particle_size: int, + num_particles: int, ): """ Perform sub tomogram averaging @@ -356,7 +354,7 @@ def average_all_particles( # Do the sub tomogram averaging _average_all_particles_Config( - config.scan, sample, rec_file, average_file, particle_size + config.scan, sample, rec_file, average_file, particle_size, num_particles ) @@ -367,48 +365,13 @@ def _average_all_particles_Config( rec_filename: str, average_filename: str, particle_size: int = 0, + num_particles: int = 0, ): """ Average particles to compute averaged reconstruction """ - def rotate_array(data, rotation, offset): - # Create the pixel indices - az = np.arange(data.shape[0]) - ay = np.arange(data.shape[1]) - ax = np.arange(data.shape[2]) - x, y, z = np.meshgrid(az, ay, ax, indexing="ij") - - # Create a stack of coordinates - xyz = np.vstack( - [ - x.reshape(-1) - offset[0], - y.reshape(-1) - offset[1], - z.reshape(-1) - offset[2], - ] - ).T - - # create transformation matrix - r = scipy.spatial.transform.Rotation.from_rotvec(rotation) - - # apply transformation - transformed_xyz = r.apply(xyz) - - # extract coordinates - x = transformed_xyz[:, 0] + offset[0] - y = transformed_xyz[:, 1] + offset[1] - z = transformed_xyz[:, 2] + offset[2] - - # Reshape - x = x.reshape(data.shape) - y = y.reshape(data.shape) - z = z.reshape(data.shape) - - # sample - result = scipy.ndimage.map_coordinates(data, [x, y, z], order=1) - return result - # Get the scan config scan = config.dict() @@ -463,14 +426,17 @@ def rotate_array(data, rotation, offset): half_length = particle_size // 2 length = 2 * half_length assert len(positions) == len(orientations) - num_particles = len(positions) + if num_particles <= 0: + num_particles = len(positions) + else: + num_particles = min(num_particles, len(positions)) print( "Averaging %d %s particles with box size %d" % (num_particles, name, length) ) # Create the average array average = np.zeros(shape=(length, length, length), dtype="float32") - num = 0 + num = 0.0 # Sort the positions and orientations by y positions, orientations = zip( @@ -500,6 +466,7 @@ def rotate_array(data, rotation, offset): # Add the contribution to the average average += data num += 1 + print("Count: ", num) # Average the sub tomograms print("Averaging map with %d particles" % num) diff --git a/src/parakeet/analyse/_extract.py b/src/parakeet/analyse/_extract.py index e731d34a..ec2e18b7 100644 --- a/src/parakeet/analyse/_extract.py +++ b/src/parakeet/analyse/_extract.py @@ -8,19 +8,20 @@ # This code is distributed under the GPLv3 license, a copy of # which is included in the root directory of this package. # +import concurrent.futures import numpy as np import mrcfile import random import h5py -import scipy.ndimage -import scipy.spatial.transform import parakeet.sample -from typing import Any from functools import singledispatch from math import sqrt, ceil +from parakeet.analyse._average_particles import lazy_map +from parakeet.analyse._average_particles import _process_sub_tomo +from parakeet.analyse._average_particles import _iterate_particles -__all__ = ["extract"] +__all__ = ["extract", "average_extracted_particles"] # Set the random seed @@ -64,7 +65,7 @@ def extract( def _extract_Config( config: parakeet.config.Config, sample: parakeet.sample.Sample, - rec_file: str, + rec_filename: str, extract_file: str, particle_size: int = 0, ): @@ -73,49 +74,14 @@ def _extract_Config( """ - def rotate_array(data, rotation, offset): - # Create the pixel indices - az = np.arange(data.shape[0]) - ay = np.arange(data.shape[1]) - ax = np.arange(data.shape[2]) - x, y, z = np.meshgrid(az, ay, ax, indexing="ij") - - # Create a stack of coordinates - xyz = np.vstack( - [ - x.reshape(-1) - offset[0], - y.reshape(-1) - offset[1], - z.reshape(-1) - offset[2], - ] - ).T - - # create transformation matrix - r = scipy.spatial.transform.Rotation.from_rotvec(rotation) - - # apply transformation - transformed_xyz = r.apply(xyz) - - # extract coordinates - x = transformed_xyz[:, 0] + offset[0] - y = transformed_xyz[:, 1] + offset[1] - z = transformed_xyz[:, 2] + offset[2] - - # Reshape - x = x.reshape(data.shape) - y = y.reshape(data.shape) - z = z.reshape(data.shape) - - # sample - result = scipy.ndimage.map_coordinates(data, [x, y, z], order=1) - return result - - # scan = config.scan.dict() + # Get the scan config + # scan = config.dict() # Get the sample centre centre = np.array(sample.centre) # Read the reconstruction file - tomo_file = mrcfile.mmap(rec_file) + tomo_file = mrcfile.mmap(rec_filename) tomogram = tomo_file.data # Get the size of the volume @@ -126,6 +92,9 @@ def rotate_array(data, rotation, offset): tomo_file.voxel_size["z"], ) ) + assert voxel_size[0] > 0 + assert voxel_size[0] == voxel_size[1] + assert voxel_size[0] == voxel_size[2] size = np.array(tomogram.shape)[[2, 0, 1]] * voxel_size # Loop through the @@ -145,7 +114,15 @@ def rotate_array(data, rotation, offset): if particle_size == 0: half_length = ( - int(ceil(sqrt((xmin - xc) ** 2 + (ymin - yc) ** 2 + (zmin - zc) ** 2))) + int( + ceil( + sqrt( + ((xmin - xc) / voxel_size[0]) ** 2 + + ((ymin - yc) / voxel_size[1]) ** 2 + + ((zmin - zc) / voxel_size[2]) ** 2 + ) + ) + ) + 1 ) else: @@ -154,12 +131,12 @@ def rotate_array(data, rotation, offset): assert len(positions) == len(orientations) num_particles = len(positions) print( - "Averaging %d %s particles with box size %d" % (num_particles, name, length) + "Extracting %d %s particles with box size %d" + % (num_particles, name, length) ) # Create the average array - extract_map: Any = [] - particle_instance = np.zeros(shape=(length, length, length), dtype="float32") + shape = (length, length, length) num = 0 # Sort the positions and orientations by y @@ -167,58 +144,99 @@ def rotate_array(data, rotation, offset): *sorted(zip(positions, orientations), key=lambda x: x[0][1]) ) - # Loop through all the particles - for i, (position, orientation) in enumerate(zip(positions, orientations)): - # Compute p within the volume - # start_position = np.array([0, scan["start_pos"], 0]) - p = position - (centre - size / 2.0) # - start_position - p[2] = size[2] - p[2] - print( - "Particle %d: position = %s, orientation = %s" - % ( - i, - "[ %.1f, %.1f, %.1f ]" % tuple(p), - "[ %.1f, %.1f, %.1f ]" % tuple(orientation), - ) - ) - - # Set the region to extract - x0 = np.floor(p).astype("int32") - half_length - x1 = np.floor(p).astype("int32") + half_length - offset = p - np.floor(p).astype("int32") + # Get the random indices + indices = [list(range(len(positions)))] - # Get the sub tomogram - print("Getting sub tomogram") - sub_tomo = tomogram[x0[1] : x1[1], x0[2] : x1[2], x0[0] : x1[0]] - if sub_tomo.shape == particle_instance.shape: - # Set the data to transform - data = sub_tomo - - # Reorder input vectors - offset = np.array(data.shape)[::-1] / 2 + offset[[1, 2, 0]] - rotation = -np.array(orientation)[[1, 2, 0]] - rotation[1] = -rotation[1] + # Create a file to store particles + handle = h5py.File(extract_file, "w") + handle["voxel_size"] = voxel_size + data_handle = handle.create_dataset( + "data", (0,) + shape, maxshape=(None,) + shape + ) - # Rotate the data - print("Rotating volume") - data = rotate_array(data, rotation, offset) + # Loop through all the particles + with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor: + for half_index, data in lazy_map( + executor, + _process_sub_tomo, + _iterate_particles( + indices, + positions, + orientations, + centre, + size, + half_length, + shape, + voxel_size, + tomogram, + ), + ): + # Add the particle to the file + data_handle.resize(num + 1, axis=0) + data_handle[num, :, :, :] = data + num += 1 + print("Count: ", num) - # Add the contribution to the average - extract_map.append(data) - num += 1 +def average_extracted_particles( + particles_filename: str, + half1_filename: str, + half2_filename: str, + num_particles: int = 0, +): + """ + Average the extracted particles - # Average the sub tomograms - print("Extracting %d particles" % num) - extract_map = np.array(extract_map) + """ - # from matplotlib import pylab - # pylab.imshow(average[half_length, :, :]) - # pylab.show() + # Open the particles file + handle = h5py.File(particles_filename, "r") + data = handle["data"] + voxel_size = tuple(handle["voxel_size"][:]) + print("Voxel size: %s" % str(voxel_size)) + + # Get the number of particles + if num_particles is None or num_particles <= 0: + num_particles = data.shape[0] + half_num_particles = num_particles // 2 + assert half_num_particles > 0 + assert num_particles <= data.shape[0] + + # Setup the arrays + half = np.zeros((2,) + data.shape[1:], dtype="float32") + num = np.zeros(2) + + # Get the random indices + indices = list( + np.random.choice(range(data.shape[0]), size=num_particles, replace=False) + ) + indices = [indices[:half_num_particles], indices[half_num_particles:]] - # Save the averaged data - print("Saving extracted particles to %s" % extract_file) - handle = h5py.File(extract_file, "w") - data_handle = handle.create_dataset("data", extract_map.shape, chunks=True) - data_handle[:] = extract_map[:] - handle.close() + # Average the particles + print("Summing particles") + for half_index, particle_indices in enumerate(indices): + for i, particle_index in enumerate(particle_indices): + print( + "Half %d: adding %d / %d" + % (half_index + 1, i + 1, len(particle_indices)) + ) + half[half_index, :, :, :] += data[particle_index, :, :, :] + num[half_index] += 1 + + # Average the sub tomograms + print("Averaging half 1 with %d particles" % num[0]) + print("Averaging half 2 with %d particles" % num[1]) + if num[0] > 0: + half[0, :, :, :] = half[0, :, :, :] / num[0] + if num[1] > 0: + half[1, :, :, :] = half[1, :, :, :] / num[1] + + # Save the averaged data + print("Saving half 1 to %s" % half1_filename) + handle = mrcfile.new(half1_filename, overwrite=True) + handle.set_data(half[0, :, :, :]) + handle.voxel_size = voxel_size + print("Saving half 2 to %s" % half2_filename) + handle = mrcfile.new(half2_filename, overwrite=True) + handle.set_data(half[1, :, :, :]) + handle.voxel_size = voxel_size diff --git a/src/parakeet/command_line/analyse/__init__.py b/src/parakeet/command_line/analyse/__init__.py index eb410477..e7d11117 100644 --- a/src/parakeet/command_line/analyse/__init__.py +++ b/src/parakeet/command_line/analyse/__init__.py @@ -3,6 +3,7 @@ from parakeet.command_line.analyse._correct import * # noqa from parakeet.command_line.analyse._extract import * # noqa from parakeet.command_line.analyse._average_particles import * # noqa +from parakeet.command_line.analyse._average_extracted_particles import * # noqa from parakeet.command_line.analyse._average_all_particles import * # noqa from parakeet.command_line.analyse._refine import * # noqa # fmt: on diff --git a/src/parakeet/command_line/analyse/_average_all_particles.py b/src/parakeet/command_line/analyse/_average_all_particles.py index f6ea0c78..2f3d38cc 100644 --- a/src/parakeet/command_line/analyse/_average_all_particles.py +++ b/src/parakeet/command_line/analyse/_average_all_particles.py @@ -89,6 +89,14 @@ def get_parser(parser: ArgumentParser = None) -> ArgumentParser: dest="particle_size", help="The size of the particles extracted (px)", ) + parser.add_argument( + "-n", + "--num_particles", + type=int, + default=0, + dest="num_particles", + help="The number of particles to use", + ) return parser @@ -107,7 +115,12 @@ def average_all_particles_impl(args): # Do the work parakeet.analyse.average_all_particles( - args.config, args.sample, args.rec, args.average, args.particle_size + args.config, + args.sample, + args.rec, + args.average, + args.particle_size, + args.num_particles, ) # Write some timing stats diff --git a/src/parakeet/command_line/analyse/_average_extracted_particles.py b/src/parakeet/command_line/analyse/_average_extracted_particles.py new file mode 100644 index 00000000..4fbaa812 --- /dev/null +++ b/src/parakeet/command_line/analyse/_average_extracted_particles.py @@ -0,0 +1,116 @@ +# +# parakeet.command_line.analyse.average_extracted_particles.py +# +# Copyright (C) 2019 Diamond Light Source and Rosalind Franklin Institute +# +# Author: James Parkhurst +# +# This code is distributed under the GPLv3 license, a copy of +# which is included in the root directory of this package. +# + + +import logging +import time +import parakeet.analyse +import parakeet.io +import parakeet.command_line +import parakeet.config +import parakeet.microscope +import parakeet.sample +from argparse import ArgumentParser +from typing import List + + +__all__ = ["average_extracted_particles"] + + +# Get the logger +logger = logging.getLogger(__name__) + + +def get_description(): + """ + Get the program description + + """ + return "Perform sub tomogram averaging" + + +def get_parser(parser: ArgumentParser = None) -> ArgumentParser: + """ + Get the parakeet.analyse.average_extracted_particles parser + + """ + + # Initialise the parser + if parser is None: + parser = ArgumentParser(description=get_description()) + + # Add some command line arguments + parser.add_argument( + "-p", + "--particles", + type=str, + default="particles.h5", + dest="particles", + help="The filename for the particles", + ) + parser.add_argument( + "-h1", + "--half1", + type=str, + default="half1.mrc", + dest="half1", + help="The filename for the particle average", + ) + parser.add_argument( + "-h2", + "--half2", + type=str, + default="half2.mrc", + dest="half2", + help="The filename for the particle average", + ) + parser.add_argument( + "-n", + "--num_particles", + type=int, + default=0, + dest="num_particles", + help="The number of particles to use", + ) + + return parser + + +def average_extracted_particles_impl(args): + """ + Perform sub tomogram averaging + + """ + + # Get the start time + start_time = time.time() + + # Configure some basic logging + parakeet.command_line.configure_logging() + + # Do the work + parakeet.analyse.average_extracted_particles( + args.particles, + args.half1, + args.half2, + args.num_particles, + ) + + # Write some timing stats + logger.info("Time taken: %.2f seconds" % (time.time() - start_time)) + + +def average_extracted_particles(args: List[str] = None): + """ + Perform sub tomogram averaging + + """ + average_extracted_particles_impl(get_parser().parse_args(args=args)) diff --git a/src/parakeet/config.py b/src/parakeet/config.py index cd71c220..17ff185d 100644 --- a/src/parakeet/config.py +++ b/src/parakeet/config.py @@ -159,7 +159,8 @@ class MoleculePose(BaseModel): position: Optional[Tuple[float, float, float]] = Field( description=( "The molecule position (A, A, A). Setting this to null or an " - "empty list will cause parakeet to give a random position" + "empty list will cause parakeet to give a random position. " + "The position is given in [x y z] order. " ), examples=[ "position: null # Assign random position", @@ -173,7 +174,8 @@ class MoleculePose(BaseModel): "The molecule orientation defined as a rotation vector where " "the direction of the vector gives the rotation axis and the " "magnitude of the vector gives the rotation angle in radians. Setting " - "this to null or an empty list will cause parakeet to give a random orientation" + "this to null or an empty list will cause parakeet to give a random " + "orientation. The axis is given in [x, y, z] order." ), examples=[ "orienation: null # Assign random orienation", @@ -671,7 +673,7 @@ class IceParameters(BaseModel): s2: float = Field(0.081, description="The standard deviation of gaussian 2") a1: float = Field(0.199, description="The amplitude of gaussian 1") a2: float = Field(0.801, description="The amplitude of gaussian 2") - density: float = Field(0.91, gt=0, description="The density of the ice (g/cm^3)") + density: float = Field(0.94, gt=0, description="The density of the ice (g/cm^3)") class Simulation(BaseModel): diff --git a/src/parakeet/inelastic.py b/src/parakeet/inelastic.py index cdd8f4cb..e99458ae 100644 --- a/src/parakeet/inelastic.py +++ b/src/parakeet/inelastic.py @@ -17,16 +17,24 @@ def effective_thickness(shape, angle): Compute the effective thickness """ - TINY = 1e-10 + TINY = 1e-5 if shape["type"] == "cube": D0 = shape["cube"]["length"] - thickness = D0 / (cos(pi * angle / 180.0) + TINY) + cos_angle = cos(pi * angle / 180.0) + if abs(cos_angle) < TINY: + cos_angle = TINY + thickness = D0 / cos_angle elif shape["type"] == "cuboid": D0 = shape["cuboid"]["length_z"] - thickness = D0 / (cos(pi * angle / 180.0) + TINY) + cos_angle = cos(pi * angle / 180.0) + if abs(cos_angle) < TINY: + cos_angle = TINY + thickness = D0 / cos_angle elif shape["type"] == "cylinder": thickness = shape["cylinder"]["radius"] * 2 - return thickness + if isinstance(thickness, list): + thickness = np.mean(thickness) + return abs(thickness) def zero_loss_fraction(shape, angle): @@ -187,7 +195,9 @@ def energy_loss_distribution(self, energy, thickness): # The energy loss distribution energy_loss_distribution = self.landau(dE, energy, thickness) - energy_loss_distribution /= self.dE_step * np.sum(energy_loss_distribution) + sum_ELD = np.sum(energy_loss_distribution) + if sum_ELD > 0: + energy_loss_distribution /= self.dE_step * sum_ELD # The zero loss distribution zero_loss_distribution = (1.0 / sqrt(pi * self.energy_spread**2)) * np.exp( @@ -274,7 +284,9 @@ def compute_inelastic_component(self, energy, thickness, position, filter_width) # The energy loss distribution P = self.landau(dE, energy, thickness) - P /= self.dE_step * np.sum(P) + sum_P = np.sum(P) + if sum_P > 0: + P /= self.dE_step * sum_P C = np.cumsum(P) * self.dE_step # Compute the fractions for the zero loss and energy losses @@ -312,3 +324,103 @@ def compute_inelastic_component(self, energy, thickness, position, filter_width) # Return the fraction and spread return fraction, spread + + +def get_energy_bins( + energy, + thickness, + energy_spread=0.8, + filter_energy=None, + filter_width=None, + dE_max=200, + dE_step=5, +): + """ + Get some energy bins with weights + """ + + # Ensure min and max such that when we split the mean dE will be at + # sensible locations (i.e. we will have one at zero) + dE_step_sub = 0.01 + assert dE_step < 10 + dE_min = -dE_step * (0.5 + int(10 / dE_step)) + dE_max = dE_step * (0.5 + int(dE_max / dE_step)) + assert dE_min < 0 + assert dE_max > dE_min + + # Check the filter width and step size + if filter_energy is not None and filter_width is not None: + # Adjust the step size + num_step = int(filter_width / dE_step) + dE_step = filter_width / num_step + + # The min and max energies + filter_min = filter_energy - filter_width / 2.0 + filter_max = filter_energy + filter_width / 2.0 + + # Set the bins + bins = [] + for i in range(num_step): + E1 = filter_min + i * dE_step + E2 = E1 + dE_step + bins.append((E1, E2)) + + else: + # Number of steps + num_step = int((dE_max - dE_min) / dE_step) + + bins = [] + for i in range(num_step): + E1 = dE_min + i * dE_step + E2 = E1 + dE_step + bins.append((E1, E2)) + + # Get the distribution of energy losses + optimizer = EnergyFilterOptimizer( + energy_spread, + dE_min=dE_min + dE_step_sub / 2.0, + dE_max=dE_max + dE_step_sub / 2.0, + dE_step=dE_step_sub, + ) + dE, distribution = optimizer.energy_loss_distribution(energy, thickness) + distribution /= np.sum(distribution) + + # The maximum spread + dE_spread_max = sqrt(dE_step**2 / 12) * sqrt(2) + + # Loop over the subdivisions and compute mean energy, spread and total + # weight. For each bin we take the distribution and compute the weighted + # mean energy loss and the weighted variance as the energy spread. The mean + # will always be within the energy bin and the variance will always be > 0 + # and < the variance of the uniform distribution within the bin. + TINY = 1e-5 + nbins = len(bins) + bin_energy = np.zeros(nbins) + bin_spread = np.zeros(nbins) + bin_weight = np.zeros(nbins) + for i in range(nbins): + E1, E2 = bins[i] + select = (dE >= E1) & (dE < E2) + dE_sub = dE[select] + P_sub = distribution[select] + dE_mean = np.mean(dE_sub) + P_tot = np.sum(P_sub) + if P_tot > 1e-7: + P_sub = P_sub / P_tot + dE_mean = np.sum(P_sub * dE_sub) + dE_spread = np.sum(P_sub * ((dE_sub - dE_mean) ** 2)) + dE_spread = sqrt(dE_spread) * sqrt(2) + else: + dE_spread = dE_spread_max + fudge = np.sqrt((dE_step_sub**2 / 12) * len(P_sub)) + assert (E2 - E1) <= (dE_step + TINY) + assert dE_mean >= E1 + assert dE_mean <= E2 + assert dE_spread >= 0 + # assert dE_spread <= (dE_spread_max + fudge) + bin_energy[i] = energy - dE_mean + bin_weight[i] = P_tot + bin_spread[i] = dE_spread + + # Return the bins and weights + return bin_energy, bin_spread, bin_weight diff --git a/src/parakeet/landau.py b/src/parakeet/landau.py index e8cfa270..a8c4e366 100644 --- a/src/parakeet/landau.py +++ b/src/parakeet/landau.py @@ -165,23 +165,25 @@ class Landau(object): """ - def __init__(self, l0=-10, l1=500, dl=0.01): + CACHE = None + + def __init__(self): """ Initialise the class - Params: - l0 (float): The minimum lambda value - l1 (float): The maximum lambda value - dl (float): The lambda step size - """ # Generate the table of values for the universal function - self.l0 = l0 - self.l1 = l1 - self.dl = dl - self.lambda_ = np.arange(l0, l1, dl) - self.phi = np.array([landau(xx) for xx in self.lambda_]) + if Landau.CACHE is None: + l0 = -10 + l1 = 200 + dl = 0.01 + lambda_ = np.arange(l0, l1, dl) + phi = np.array([landau(xx) for xx in lambda_]) + Landau.CACHE = (l0, l1, dl, lambda_, phi) + + # Get the values from the cache + self.l0, self.l1, self.dl, self.lambda_, self.phi = Landau.CACHE def __call__(self, dE, energy, thickness): """ diff --git a/src/parakeet/sample/__init__.py b/src/parakeet/sample/__init__.py index b4af9661..c127cacc 100644 --- a/src/parakeet/sample/__init__.py +++ b/src/parakeet/sample/__init__.py @@ -304,23 +304,33 @@ def shape_enclosed_box(centre, shape): """ + # The margin + margin = np.array(shape.get("margin", (0, 0, 0))) + def cube_enclosed_box(cube): length = cube["length"] - return ((0, 0, 0), (length, length, length)) + return ( + np.array((0, 0, 0)) + margin, + np.array((length, length, length)) - margin, + ) def cuboid_enclosed_box(cuboid): length_x = cuboid["length_x"] length_y = cuboid["length_y"] length_z = cuboid["length_z"] - return ((0, 0, 0), (length_x, length_y, length_z)) + return ( + np.array((0, 0, 0)) + margin, + np.array((length_x, length_y, length_z)) - margin, + ) def cylinder_enclosed_box(cylinder): length = cylinder["length"] radius = np.mean(cylinder["radius"]) - return ( - (radius * (1 - 1 / sqrt(2)), 0, radius * (1 - 1 / sqrt(2))), - (radius * (1 + 1 / sqrt(2)), length, radius * (1 + 1 / sqrt(2))), - ) + x0 = -(radius - margin[0]) / sqrt(2) + radius + x1 = (radius - margin[0]) / sqrt(2) + radius + z0 = -(radius - margin[2]) / sqrt(2) + radius + z1 = (radius - margin[2]) / sqrt(2) + radius + return ((x0, 0 + margin[1], z0), (x1, length - margin[1], z1)) # The enclosed box x0, x1 = np.array( @@ -334,11 +344,8 @@ def cylinder_enclosed_box(cylinder): # The offset offset = centre - (x1 + x0) / 2.0 - # The margin - margin = np.array(shape.get("margin", (0, 0, 0))) - # Return the bounding box - return (x0 + offset + margin, x1 + offset - margin) + return (x0 + offset, x1 + offset) def is_shape_inside_box(box, centre, shape): diff --git a/src/parakeet/sample/distribute.py b/src/parakeet/sample/distribute.py index 527fbd9d..8f1d828a 100644 --- a/src/parakeet/sample/distribute.py +++ b/src/parakeet/sample/distribute.py @@ -228,22 +228,26 @@ def shape_volume_object(centre: tuple, shape: dict): """ - def make_cube_volume(centre, cube): + def make_cube_volume(centre, cube, margin): length = cube["length"] lower = np.array(centre) - length / 2.0 upper = lower + length + lower += np.array(margin) + upper -= np.array(margin) return CuboidVolume(lower, upper) - def make_cuboid_volume(centre, cuboid): + def make_cuboid_volume(centre, cuboid, margin): length_x = cuboid["length_x"] length_y = cuboid["length_y"] length_z = cuboid["length_z"] length = np.array((length_x, length_y, length_z)) lower = np.array(centre) - length / 2.0 upper = lower + length + lower += np.array(margin) + upper -= np.array(margin) return CuboidVolume(lower, upper) - def make_cylinder_volume(centre, cylinder): + def make_cylinder_volume(centre, cylinder, margin): # Get the cylinder params length = cylinder["length"] radius = cylinder["radius"] @@ -267,6 +271,11 @@ def make_cylinder_volume(centre, cylinder): np.array((centre[0], centre[2])) + np.array(list(zip(offset_x, offset_z))) ) + # Add a margin + lower += margin[1] + upper -= margin[1] + radius = [max(1, r - margin[0]) for r in radius] + # Return volume return CylindricalVolume(lower, upper, centre, radius) @@ -274,7 +283,7 @@ def make_cylinder_volume(centre, cylinder): "cube": make_cube_volume, "cuboid": make_cuboid_volume, "cylinder": make_cylinder_volume, - }[shape["type"]](centre, shape[shape["type"]]) + }[shape["type"]](centre, shape[shape["type"]], shape["margin"]) def distribute_particles_uniformly( diff --git a/src/parakeet/scan.py b/src/parakeet/scan.py index 553b51aa..370b6456 100644 --- a/src/parakeet/scan.py +++ b/src/parakeet/scan.py @@ -392,7 +392,7 @@ def manual( elif positions is None and angles is not None: positions = np.zeros(len(angles)) if defocus_offset is None: - defocus_offset = np.zeros(angles.shape[0]) # type: ignore + defocus_offset = np.zeros(len(angles)) # type: ignore assert angles is not None assert positions is not None assert len(angles) == len(positions) diff --git a/src/parakeet/simulate/_cbed.py b/src/parakeet/simulate/_cbed.py index eb369ca6..428a43d6 100644 --- a/src/parakeet/simulate/_cbed.py +++ b/src/parakeet/simulate/_cbed.py @@ -12,7 +12,6 @@ import logging import numpy as np import time -import warnings import parakeet.config import parakeet.dqe import parakeet.freeze @@ -23,11 +22,11 @@ import parakeet.simulate from parakeet.config import Device from parakeet.simulate.simulation import Simulation +from parakeet.simulate.engine import SimulationEngine from parakeet.microscope import Microscope from parakeet.scan import Scan from functools import singledispatch from math import pi -from collections.abc import Iterable from scipy.spatial.transform import Rotation as R @@ -37,12 +36,6 @@ # Get the logger logger = logging.getLogger(__name__) -# Try to input MULTEM -try: - import multem -except ImportError: - warnings.warn("Could not import MULTEM") - class CBEDImageSimulator(object): """ @@ -70,85 +63,6 @@ def __init__( self.device = device self.gpu_id = gpu_id - def get_masker( - self, - index, - input_multislice, - pixel_size, - drift, - origin, - offset, - orientation, - shift, - ): - """ - Get the masker object for the ice specification - - """ - - # Create the masker - masker = multem.Masker(input_multislice.nx, input_multislice.ny, pixel_size) - - # Get the sample centre - shape = self.sample.shape - centre = np.array(self.sample.centre) - drift = np.array(drift) - detector_origin = np.array([origin[0], origin[1], 0]) - centre = centre + offset - detector_origin - shift - - # Set the shape - if shape["type"] == "cube": - length = shape["cube"]["length"] - masker.set_cuboid( - ( - centre[0] - length / 2, - centre[1] - length / 2, - centre[2] - length / 2, - ), - (length, length, length), - ) - elif shape["type"] == "cuboid": - length_x = shape["cuboid"]["length_x"] - length_y = shape["cuboid"]["length_y"] - length_z = shape["cuboid"]["length_z"] - masker.set_cuboid( - ( - centre[0] - length_x / 2, - centre[1] - length_y / 2, - centre[2] - length_z / 2, - ), - (length_x, length_y, length_z), - ) - elif shape["type"] == "cylinder": - radius = shape["cylinder"]["radius"] - if not isinstance(radius, Iterable): - radius = [radius] - length = shape["cylinder"]["length"] - offset_x = shape["cylinder"].get("offset_x", None) - offset_z = shape["cylinder"].get("offset_z", None) - axis = shape["cylinder"].get("axis", (0, 1, 0)) - if offset_x is None: - offset_x = [0] * len(radius) - if offset_z is None: - offset_z = [0] * len(radius) - masker.set_cylinder( - (centre[0], centre[1] - length / 2, centre[2]), - axis, - length, - list(radius), - list(offset_x), - list(offset_z), - ) - - # Rotate unless we have a single particle type simulation - if self.scan.is_uniform_angular_scan: - masker.set_rotation(centre, (0, 0, 0)) - else: - masker.set_rotation(centre, orientation) - - # Get the masker - return masker - def __call__(self, index): """ Simulate a single frame @@ -202,28 +116,28 @@ def __call__(self, index): z_centre = self.sample.centre[2] # Create the multem input multislice object - input_multislice = ( - parakeet.simulate.simulation.create_input_multislice_diffraction( - self.microscope, - self.simulation["slice_thickness"], - self.simulation["margin"] + self.simulation["padding"], - "CBED", - z_centre, - ) + simulate = SimulationEngine( + self.device, + self.gpu_id, + self.microscope, + self.simulation["slice_thickness"], + self.simulation["margin"] + self.simulation["padding"], + "CBED", + z_centre, ) # Set the specimen size - input_multislice.spec_lx = x_fov + offset * 2 - input_multislice.spec_ly = y_fov + offset * 2 - input_multislice.spec_lz = self.sample.containing_box[1][2] + simulate.input.spec_lx = x_fov + offset * 2 + simulate.input.spec_ly = y_fov + offset * 2 + simulate.input.spec_lz = self.sample.containing_box[1][2] # Set the beam tilt - input_multislice.theta += beam_tilt_theta - input_multislice.phi += beam_tilt_phi + simulate.input.theta += beam_tilt_theta + simulate.input.phi += beam_tilt_phi # Compute the B factor if self.simulation["radiation_damage_model"]: - input_multislice.static_B_factor = ( + simulate.input.static_B_factor = ( 8 * pi**2 * ( @@ -233,7 +147,7 @@ def __call__(self, index): ) ) else: - input_multislice.static_B_factor = 0 + simulate.input.static_B_factor = 0 # Set the atoms in the input after translating them for the offset atoms = self.sample.get_atoms() @@ -264,7 +178,7 @@ def __call__(self, index): atoms.data = atoms.data[select] # Translate for the detector - input_multislice.spec_atoms = atoms.translate( + simulate.input.spec_atoms = atoms.translate( (offset - origin[0], offset - origin[1], 0) ).to_multem() logger.info(" Got spec atoms") @@ -285,9 +199,8 @@ def __call__(self, index): if self.simulation["ice"] == True: # Get the masker - masker = self.get_masker( + masker = simulate.get_masker( index, - input_multislice, pixel_size, drift, origin, @@ -297,22 +210,21 @@ def __call__(self, index): ) # Run the simulation - output_multislice = multem.simulate(system_conf, input_multislice, masker) + image = simulate.image(masker) else: # Set the incident wave - input_multislice.iw_x = [0] # input_multislice.spec_lx/2 - input_multislice.iw_y = [0] # input_multislice.spec_ly/2 + simulate.input.iw_x = [0] # simulate.input.spec_lx/2 + simulate.input.iw_y = [0] # simulate.input.spec_ly/2 # Run the simulation logger.info("Simulating") - output_multislice = multem.simulate(system_conf, input_multislice) + image = simulate.image(masker) # Get the ideal image data # Multem outputs data in column major format. In C++ and Python we # generally deal with data in row major format so we must do a # transpose here. - image = np.array(output_multislice.data[0].m2psi_tot).T x0 = padding y0 = padding x1 = image.shape[1] - padding diff --git a/src/parakeet/simulate/_ctf.py b/src/parakeet/simulate/_ctf.py index 880828b7..7712f15a 100644 --- a/src/parakeet/simulate/_ctf.py +++ b/src/parakeet/simulate/_ctf.py @@ -17,12 +17,10 @@ import parakeet.futures import parakeet.inelastic import parakeet.sample -import warnings from parakeet.microscope import Microscope from functools import singledispatch from parakeet.simulate.simulation import Simulation -from parakeet.simulate.simulation import create_system_configuration -from parakeet.simulate.simulation import create_input_multislice +from parakeet.simulate.engine import SimulationEngine # Get the logger logger = logging.getLogger(__name__) @@ -31,13 +29,6 @@ __all__ = ["ctf"] -# Try to input MULTEM -try: - import multem -except ImportError: - warnings.warn("Could not import MULTEM") - - class CTFSimulator(object): """ A class to do the actual simulation @@ -73,25 +64,24 @@ def __call__(self, index): y_fov = ny * pixel_size # Create the multem system configuration - system_conf = create_system_configuration("cpu") - - # Create the multem input multislice object - input_multislice = create_input_multislice( + simulate = SimulationEngine( + "cpu", + 0, self.microscope, self.simulation["slice_thickness"], self.simulation["margin"], "HRTEM", ) - input_multislice.nx = nx - input_multislice.ny = ny + simulate.input.nx = nx + simulate.input.ny = ny # Set the specimen size - input_multislice.spec_lx = x_fov - input_multislice.spec_ly = y_fov - input_multislice.spec_lz = x_fov # self.sample.containing_box[1][2] + simulate.input.spec_lx = x_fov + simulate.input.spec_ly = y_fov + simulate.input.spec_lz = x_fov # self.sample.containing_box[1][2] # Run the simulation - image = np.array(multem.compute_ctf(system_conf, input_multislice)).T + image = simulate.ctf() image = np.fft.fftshift(image) # Compute the image scaled with Poisson noise diff --git a/src/parakeet/simulate/_exit_wave.py b/src/parakeet/simulate/_exit_wave.py index d381fe07..0c90457e 100644 --- a/src/parakeet/simulate/_exit_wave.py +++ b/src/parakeet/simulate/_exit_wave.py @@ -12,7 +12,6 @@ import logging import numpy as np import time -import warnings import parakeet.config import parakeet.dqe import parakeet.freeze @@ -23,10 +22,10 @@ import parakeet.simulate from parakeet.config import Device from parakeet.simulate.simulation import Simulation +from parakeet.simulate.engine import SimulationEngine from parakeet.microscope import Microscope from functools import singledispatch from math import pi -from collections.abc import Iterable from scipy.spatial.transform import Rotation as R @@ -36,12 +35,6 @@ # Get the logger logger = logging.getLogger(__name__) -# Try to input MULTEM -try: - import multem -except ImportError: - warnings.warn("Could not import MULTEM") - class ExitWaveImageSimulator(object): """ @@ -69,96 +62,6 @@ def __init__( self.device = device self.gpu_id = gpu_id - def get_masker( - self, - index, - input_multislice, - pixel_size, - drift, - origin, - offset, - orientation, - shift, - ): - """ - Get the masker object for the ice specification - - """ - - # Create the masker - masker = multem.Masker(input_multislice.nx, input_multislice.ny, pixel_size) - - # Set the ice parameters - ice_parameters = multem.IceParameters() - ice_parameters.m1 = self.simulation["ice_parameters"]["m1"] - ice_parameters.m2 = self.simulation["ice_parameters"]["m2"] - ice_parameters.s1 = self.simulation["ice_parameters"]["s1"] - ice_parameters.s2 = self.simulation["ice_parameters"]["s2"] - ice_parameters.a1 = self.simulation["ice_parameters"]["a1"] - ice_parameters.a2 = self.simulation["ice_parameters"]["a2"] - ice_parameters.density = self.simulation["ice_parameters"]["density"] - masker.set_ice_parameters(ice_parameters) - - # Get the sample centre - shape = self.sample.shape - centre = np.array(self.sample.centre) - drift = np.array(drift) - detector_origin = np.array([origin[0], origin[1], 0]) - centre = centre + offset - detector_origin - shift - - # Set the shape - if shape["type"] == "cube": - length = shape["cube"]["length"] - masker.set_cuboid( - ( - centre[0] - length / 2, - centre[1] - length / 2, - centre[2] - length / 2, - ), - (length, length, length), - ) - elif shape["type"] == "cuboid": - length_x = shape["cuboid"]["length_x"] - length_y = shape["cuboid"]["length_y"] - length_z = shape["cuboid"]["length_z"] - masker.set_cuboid( - ( - centre[0] - length_x / 2, - centre[1] - length_y / 2, - centre[2] - length_z / 2, - ), - (length_x, length_y, length_z), - ) - elif shape["type"] == "cylinder": - radius = shape["cylinder"]["radius"] - if not isinstance(radius, Iterable): - radius = [radius] - length = shape["cylinder"]["length"] - offset_x = shape["cylinder"].get("offset_x", None) - offset_z = shape["cylinder"].get("offset_z", None) - axis = shape["cylinder"].get("axis", (0, 1, 0)) - if offset_x is None: - offset_x = [0] * len(radius) - if offset_z is None: - offset_z = [0] * len(radius) - masker.set_cylinder( - (centre[0], centre[1] - length / 2, centre[2]), - axis, - length, - list(radius), - list(offset_x), - list(offset_z), - ) - - # Rotate unless we have a single particle type simulation - if self.scan.is_uniform_angular_scan: - masker.set_rotation(centre, (0, 0, 0)) - else: - masker.set_rotation(centre, orientation) - - # Get the masker - return masker - def __call__(self, index): """ Simulate a single frame @@ -202,17 +105,13 @@ def __call__(self, index): # padding_offset = padding * pixel_size offset = (padding + margin) * pixel_size - # Create the multem system configuration - system_conf = parakeet.simulate.simulation.create_system_configuration( - self.device, - self.gpu_id, - ) - # The Z centre z_centre = self.sample.centre[2] - # Create the multem input multislice object - input_multislice = parakeet.simulate.simulation.create_input_multislice( + # Create the multem system configuration + simulate = SimulationEngine( + self.device, + self.gpu_id, self.microscope, self.simulation["slice_thickness"], self.simulation["margin"] + self.simulation["padding"], @@ -221,17 +120,17 @@ def __call__(self, index): ) # Set the specimen size - input_multislice.spec_lx = x_fov + offset * 2 - input_multislice.spec_ly = y_fov + offset * 2 - input_multislice.spec_lz = self.sample.containing_box[1][2] + simulate.input.spec_lx = x_fov + offset * 2 + simulate.input.spec_ly = y_fov + offset * 2 + simulate.input.spec_lz = self.sample.containing_box[1][2] # Set the beam tilt - input_multislice.theta += beam_tilt_theta - input_multislice.phi += beam_tilt_phi + simulate.input.theta += beam_tilt_theta + simulate.input.phi += beam_tilt_phi # Compute the B factor if self.simulation["radiation_damage_model"]: - input_multislice.static_B_factor = ( + simulate.input.static_B_factor = ( 8 * pi**2 * ( @@ -241,7 +140,7 @@ def __call__(self, index): ) ) else: - input_multislice.static_B_factor = 0 + simulate.input.static_B_factor = 0 # Set the atoms in the input after translating them for the offset atoms = self.sample.get_atoms() @@ -272,7 +171,7 @@ def __call__(self, index): atoms.data = atoms.data[select] # Translate for the detector - input_multislice.spec_atoms = atoms.translate( + simulate.input.spec_atoms = atoms.translate( (offset - origin[0], offset - origin[1], 0) ).to_multem() logger.info(" Got spec atoms") @@ -293,30 +192,30 @@ def __call__(self, index): if self.simulation["ice"] == True: # Get the masker - masker = self.get_masker( + masker = simulate.masker( index, - input_multislice, pixel_size, - drift, origin, offset, orientation, position, + self.sample, + self.scan, + self.simulation, ) # Run the simulation - output_multislice = multem.simulate(system_conf, input_multislice, masker) + image = simulate.image(masker) else: # Run the simulation logger.info("Simulating") - output_multislice = multem.simulate(system_conf, input_multislice) + image = simulate.image() # Get the ideal image data # Multem outputs data in column major format. In C++ and Python we # generally deal with data in row major format so we must do a # transpose here. - image = np.array(output_multislice.data[0].psi_coh).T x0 = padding y0 = padding x1 = image.shape[1] - padding diff --git a/src/parakeet/simulate/_image.py b/src/parakeet/simulate/_image.py index d65d0efd..04b9c2a9 100644 --- a/src/parakeet/simulate/_image.py +++ b/src/parakeet/simulate/_image.py @@ -24,9 +24,6 @@ from parakeet.simulate.simulation import Simulation -Device = parakeet.config.Device - - __all__ = ["image"] diff --git a/src/parakeet/simulate/_optics.py b/src/parakeet/simulate/_optics.py index 0e684a7c..6bc27419 100644 --- a/src/parakeet/simulate/_optics.py +++ b/src/parakeet/simulate/_optics.py @@ -12,7 +12,6 @@ import copy import logging import numpy as np -import warnings import parakeet.config import parakeet.dqe import parakeet.freeze @@ -24,20 +23,15 @@ from parakeet.microscope import Microscope from parakeet.scan import Scan from functools import singledispatch -from math import sqrt from parakeet.simulate.simulation import Simulation +from parakeet.simulate.engine import SimulationEngine +from parakeet.microscope import Microscope +from parakeet.scan import Scan __all__ = ["optics"] -# Try to input MULTEM -try: - import multem -except ImportError: - warnings.warn("Could not import MULTEM") - - # Get the logger logger = logging.getLogger(__name__) @@ -94,28 +88,27 @@ def compute_image( gpu_id, defocus=None, ): - # Create the multem system configuration - system_conf = parakeet.simulate.simulation.create_system_configuration( - device, - gpu_id, - ) - # Set the defocus if defocus is not None: microscope.lens.c_10 = defocus - # Create the multem input multislice object - input_multislice = parakeet.simulate.simulation.create_input_multislice( - microscope, simulation["slice_thickness"], simulation["margin"], "HRTEM" + # Create the simulation engine + simulate = SimulationEngine( + device, + gpu_id, + microscope, + simulation["slice_thickness"], + simulation["margin"], + "HRTEM", ) # Set the specimen size - input_multislice.spec_lx = x_fov + offset * 2 - input_multislice.spec_ly = y_fov + offset * 2 - input_multislice.spec_lz = x_fov # self.sample.containing_box[1][2] + simulate.input.spec_lx = x_fov + offset * 2 + simulate.input.spec_ly = y_fov + offset * 2 + simulate.input.spec_lz = x_fov # self.sample.containing_box[1][2] # Compute and apply the CTF - ctf = np.array(multem.compute_ctf(system_conf, input_multislice)).T + ctf = simulate.ctf() # Add the effect of the phase plate if microscope.phase_plate.use: @@ -213,202 +206,142 @@ def compute_image( # Scale the image by the fraction of electrons image *= electron_fraction - elif self.simulation["inelastic_model"] == "mp_loss": - # Set the filter width - filter_width = self.simulation["mp_loss_width"] # eV - - # Compute the energy and spread of the plasmon peak + else: + # Get the effective thickness thickness = parakeet.inelastic.effective_thickness(shape, angle) # A - peak, sigma = parakeet.inelastic.most_probable_loss( - microscope.beam.energy, shape, angle - ) # eV - - # Save the energy and energy spread - beam_energy = microscope.beam.energy * 1000 # eV - beam_energy_spread = microscope.beam.energy_spread # dE / E - # beam_energy_sigma = (1.0 / sqrt(2)) * beam_energy_spread * beam_energy # eV - - # Set a maximum peak energy loss - peak = min(peak, beam_energy * 0.1) # eV - - # Make optimizer - optimizer = parakeet.inelastic.EnergyFilterOptimizer(dE_min=-60, dE_max=200) - assert self.simulation["mp_loss_position"] in ["peak", "optimal"] - if self.simulation["mp_loss_position"] != "peak": - peak = optimizer(beam_energy, thickness, filter_width=filter_width) - - # Compute elastic fraction and spread - elastic_fraction, elastic_spread = optimizer.compute_elastic_component( - beam_energy, thickness, peak, filter_width - ) - - # Compute inelastic fraction and spread - ( - inelastic_fraction, - inelastic_spread, - ) = optimizer.compute_inelastic_component( - beam_energy, thickness, peak, filter_width - ) - - # Compute the spread - elastic_spread = elastic_spread / beam_energy # dE / E - inelastic_spread = inelastic_spread / beam_energy # dE / E - - # Set the spread for the zero loss image - microscope.beam.energy_spread = elastic_spread # dE / E - - # Compute the zero loss image - image1 = compute_image( - psi, - microscope, - self.simulation, - x_fov, - y_fov, - offset, - self.device, - self.gpu_id, - defocus, - ) - - # Add the energy loss - microscope.beam.energy = (beam_energy - peak) / 1000.0 # keV - - # Compute the energy spread of the plasmon peak - microscope.beam.energy_spread = ( - beam_energy_spread + inelastic_spread - ) # dE / E - print("Energy: %f keV" % microscope.beam.energy) - print("Energy spread: %f ppm" % microscope.beam.energy_spread) - - # Compute the MPL image - image2 = compute_image( - psi, - microscope, - self.simulation, - x_fov, - y_fov, - offset, - self.device, - self.gpu_id, - defocus, - ) - - # Save the energy shift - energy_shift = peak - - # Compute the zero loss and mpl image fraction - electron_fraction = elastic_fraction + inelastic_fraction - - # Add the images incoherently and scale the image by the fraction of electrons - image = elastic_fraction * image1 + inelastic_fraction * image2 - - elif self.simulation["inelastic_model"] == "unfiltered": - # Compute the energy and spread of the plasmon peak - peak, sigma = parakeet.inelastic.most_probable_loss( - microscope.beam.energy, shape, angle - ) # eV - peak = min(peak, 1000 * microscope.beam.energy * 0.1) # eV - spread = sigma * sqrt(2) / (microscope.beam.energy * 1000) # dE / E - - # Compute the zero loss image - image1 = compute_image( - psi, - microscope, - self.simulation, - x_fov, - y_fov, - offset, - self.device, - self.gpu_id, - defocus, - ) - - # Add the energy loss - microscope.beam.energy -= peak / 1000 # keV - - # Compute the energy spread of the plasmon peak - microscope.beam.energy_spread += spread # dE / E - print("Energy: %f keV" % microscope.beam.energy) - print("Energy spread: %f ppm" % microscope.beam.energy_spread) - - # Compute the MPL image - image2 = compute_image( - psi, - microscope, - self.simulation, - x_fov, - y_fov, - offset, - self.device, - self.gpu_id, - defocus, - ) - - # Compute the zero loss and mpl image fraction - zero_loss_fraction = parakeet.inelastic.zero_loss_fraction(shape, angle) - mp_loss_fraction = parakeet.inelastic.mp_loss_fraction(shape, angle) - electron_fraction = zero_loss_fraction + mp_loss_fraction - - # Add the images incoherently and scale the image by the fraction of electrons - image = zero_loss_fraction * image1 + mp_loss_fraction * image2 - - elif self.simulation["inelastic_model"] == "cc_corrected": - # Set the Cs and CC to zero - microscope.lens.c_30 = 0 - microscope.lens.c_c = 0 + if self.simulation["inelastic_model"] == "unfiltered": + # Get the energy bins + bin_energy, bin_spread, bin_weight = parakeet.inelastic.get_energy_bins( + energy=microscope.beam.energy * 1000, # eV + thickness=thickness, + energy_spread=microscope.beam.energy_spread + * microscope.beam.energy + * 1000, # dE + ) - # Compute the energy and spread of the plasmon peak - peak, sigma = parakeet.inelastic.most_probable_loss( - microscope.beam.energy, shape, angle - ) - peak /= 1000.0 - peak = min(peak, microscope.beam.energy * 0.1) - spread = sigma * sqrt(2) / (microscope.beam.energy * 1000) + elif self.simulation["inelastic_model"] == "cc_corrected": + # Get the energy bins + bin_energy, bin_spread, bin_weight = parakeet.inelastic.get_energy_bins( + energy=microscope.beam.energy * 1000, # eV + thickness=thickness, + energy_spread=microscope.beam.energy_spread + * microscope.beam.energy + * 1000, # dE + ) - # Compute the zero loss image - image1 = compute_image( - psi, - microscope, - self.simulation, - x_fov, - y_fov, - offset, - self.device, - self.gpu_id, - defocus, - ) + # Set the Cs and CC to zero + microscope.lens.c_30 = 0 + microscope.lens.c_c = 0 - # Add the energy loss - microscope.beam.energy -= peak + elif self.simulation["inelastic_model"] == "mp_loss": + # Set the filter width + filter_width = self.simulation["mp_loss_width"] # eV - # Compute the energy spread of the plasmon peak - microscope.beam.energy_spread += spread - print("Energy: %f keV" % microscope.beam.energy) - print("Energy spread: %f ppm" % microscope.beam.energy_spread) + # Make optimizer + optimizer = parakeet.inelastic.EnergyFilterOptimizer( + dE_min=-60, dE_max=200 + ) + assert self.simulation["mp_loss_position"] in ["peak", "optimal"] + + # Compute the energy and spread of the plasmon peak + if self.simulation["mp_loss_position"] != "peak": + peak = optimizer( + microscope.beam.energy, thickness, filter_width=filter_width + ) + else: + peak, sigma = parakeet.inelastic.most_probable_loss( + microscope.beam.energy, shape, angle + ) # eV + + # Set a maximum peak energy loss at 10% of beam energy + peak = min(peak, microscope.beam.energy * 1000 * 0.1) # eV + + # Get the energy bins + bin_energy, bin_spread, bin_weight = parakeet.inelastic.get_energy_bins( + energy=microscope.beam.energy * 1000, # eV + thickness=thickness, + energy_spread=microscope.beam.energy_spread + * microscope.beam.energy + * 1000, # dE + filter_energy=peak, + filter_width=filter_width, + ) - # Compute the MPL image - image2 = compute_image( - psi, - microscope, - self.simulation, - x_fov, - y_fov, - offset, - self.device, - self.gpu_id, - defocus, - ) + else: + raise RuntimeError("Unknown inelastic model") + + # Get the threshold to exclude bins that don't contribute much + threshold = min(0.01 / len(bin_energy), max(bin_weight)) + print("Threshold weight: %f" % (threshold)) + + # Select based on threshold + selection = bin_weight >= threshold + bin_energy = bin_energy[selection] + bin_spread = bin_spread[selection] + bin_weight = bin_weight[selection] + + # Get the basic energy and defocus + energy0 = microscope.beam.energy + defocus0 = microscope.lens.c_10 + + # Energy and energy spread + energy1 = bin_energy / 1000.0 # keV + energy_spread1 = bin_spread / bin_energy # dE / E + + # Compute the defocus at this point + # Energy loss is positive. + # Energy loss results in over focus which is also positive + c_c_A = microscope.lens.c_c * 1e7 # A + dE_E = (energy0 - energy1) / energy0 + defocus1 = defocus0 + c_c_A * dE_E # A + + # Adjust defocus to mean + # defocus_mean = np.average(defocus1, weights=bin_weight) + # defocus1 = defocus1 + (defocus0 - defocus_mean) + + # Loop through all energies and sum images + image = None + for energy, energy_spread, defocus, weight in zip( + energy1, energy_spread1, defocus1, bin_weight + ): + # Add the energy loss + microscope.beam.energy = energy # keV + + # Compute the energy spread + microscope.beam.energy_spread = energy_spread # dE / E + + # Print some details + print( + "Energy: %f eV; Energy spread: %f eV; Weight: %f; Defocus: %f" + % ( + microscope.beam.energy * 1000, + microscope.beam.energy_spread * microscope.beam.energy * 1000, + weight, + defocus, + ) + ) - # Compute the zero loss and mpl image fraction - zero_loss_fraction = parakeet.inelastic.zero_loss_fraction(shape, angle) - mp_loss_fraction = parakeet.inelastic.mp_loss_fraction(shape, angle) - electron_fraction = zero_loss_fraction + mp_loss_fraction + # Compute the MPL image + image_n = weight * compute_image( + psi, + microscope, + self.simulation, + x_fov, + y_fov, + offset, + self.device, + self.gpu_id, + defocus, + ) - # Add the images incoherently and scale the image by the fraction of electrons - image = zero_loss_fraction * image1 + mp_loss_fraction * image2 + # Add image component + if image is None: + image = image_n + else: + image += image_n - else: - raise RuntimeError("Unknown inelastic model") + # Compute the electron fraction + electron_fraction = np.sum(bin_weight) # Print the electron fraction print("Electron fraction = %.2f" % electron_fraction) diff --git a/src/parakeet/simulate/_potential.py b/src/parakeet/simulate/_potential.py index 07757e34..78e44056 100644 --- a/src/parakeet/simulate/_potential.py +++ b/src/parakeet/simulate/_potential.py @@ -12,7 +12,6 @@ import logging import mrcfile import numpy as np -import warnings import parakeet.config import parakeet.dqe import parakeet.freeze @@ -20,10 +19,10 @@ import parakeet.inelastic import parakeet.sample from parakeet.config import Device -from parakeet.sample import Sample from parakeet.microscope import Microscope from parakeet.scan import Scan from parakeet.simulate.simulation import Simulation +from parakeet.simulate.engine import SimulationEngine from functools import singledispatch from math import pi, floor from scipy.spatial.transform import Rotation as R @@ -36,13 +35,6 @@ logger = logging.getLogger(__name__) -# Try to input MULTEM -try: - import multem -except ImportError: - warnings.warn("Could not import MULTEM") - - class ProjectedPotentialSimulator(object): """ A class to do the actual simulation @@ -109,17 +101,13 @@ def __call__(self, index): x0 = (-offset, -offset) x1 = (x_fov + offset, y_fov + offset) - # Create the multem system configuration - system_conf = parakeet.simulate.simulation.create_system_configuration( - self.device, - self.gpu_id, - ) - # The Z centre z_centre = self.sample.centre[2] # Create the multem input multislice object - input_multislice = parakeet.simulate.simulation.create_input_multislice( + simulate = SimulationEngine( + self.device, + self.gpu_id, self.microscope, self.simulation["slice_thickness"], self.simulation["margin"], @@ -128,9 +116,9 @@ def __call__(self, index): ) # Set the specimen size - input_multislice.spec_lx = x_fov + offset * 2 - input_multislice.spec_ly = y_fov + offset * 2 - input_multislice.spec_lz = self.sample.containing_box[1][2] + simulate.input.spec_lx = x_fov + offset * 2 + simulate.input.spec_ly = y_fov + offset * 2 + simulate.input.spec_lz = self.sample.containing_box[1][2] # Set the atoms in the input after translating them for the offset atoms = self.sample.get_atoms_in_fov(x0, x1) @@ -151,7 +139,7 @@ def __call__(self, index): atoms.data["z"] = coords[:, 2] origin = (0, 0) - input_multislice.spec_atoms = atoms.translate( + simulate.input.spec_atoms = atoms.translate( (offset - origin[0], offset - origin[1], 0) ).to_multem() logger.info(" Got spec atoms") @@ -169,19 +157,8 @@ def __call__(self, index): ) potential.voxel_size = tuple((pixel_size, pixel_size, slice_thickness)) - def callback(z0, z1, V): - V = np.array(V) - zc = (z0 + z1) / 2.0 - index = int(floor((zc - volume_z0) / slice_thickness)) - print( - "Calculating potential for slice: %.2f -> %.2f (index: %d)" - % (z0, z1, index) - ) - if index < potential.data.shape[0]: - potential.data[index, :, :] = V[margin:-margin, margin:-margin].T - - # Run the simulation - multem.compute_projected_potential(system_conf, input_multislice, callback) + # Simulate the potential + simulate.potential(potential, volume_z0) # Compute the image scaled with Poisson noise return (index, None, None) @@ -190,7 +167,7 @@ def callback(z0, z1, V): def simulation_factory( potential_prefix: str, microscope: Microscope, - sample: Sample, + sample: parakeet.sample.Sample, scan: Scan, simulation: dict = None, multiprocessing: dict = None, diff --git a/src/parakeet/simulate/_simple.py b/src/parakeet/simulate/_simple.py index d6dba46a..29b053d7 100644 --- a/src/parakeet/simulate/_simple.py +++ b/src/parakeet/simulate/_simple.py @@ -18,12 +18,10 @@ import parakeet.inelastic import parakeet.io import parakeet.sample -import warnings -from parakeet.microscope import Microscope from functools import singledispatch +from parakeet.microscope import Microscope from parakeet.simulate.simulation import Simulation - -Device = parakeet.config.Device +from parakeet.simulate.engine import SimulationEngine __all__ = ["simple"] @@ -33,13 +31,6 @@ logger = logging.getLogger(__name__) -# Try to input MULTEM -try: - import multem -except ImportError: - warnings.warn("Could not import MULTEM") - - class SimpleImageSimulator(object): """ A class to do the actual simulation @@ -91,21 +82,13 @@ def __call__(self, index): # Get the specimen atoms logger.info(f"Simulating image {index+1}") - # Set the rotation angle - # input_multislice.spec_rot_theta = angle - # input_multislice.spec_rot_u0 = simulation.scan.axis - # x0 = (-offset, -offset) # x1 = (x_fov + offset, y_fov + offset) # Create the multem system configuration - system_conf = parakeet.simulate.simulation.create_system_configuration( + simulate = SimulationEngine( self.device, self.gpu_id, - ) - - # Create the multem input multislice object - input_multislice = parakeet.simulate.simulation.create_input_multislice( self.microscope, self.simulation["slice_thickness"], self.simulation["margin"], @@ -113,23 +96,17 @@ def __call__(self, index): ) # Set the specimen size - input_multislice.spec_lx = x_fov + offset * 2 - input_multislice.spec_ly = y_fov + offset * 2 - input_multislice.spec_lz = np.max(self.atoms.data["z"]) + simulate.input.spec_lx = x_fov + offset * 2 + simulate.input.spec_ly = y_fov + offset * 2 + simulate.input.spec_lz = np.max(self.atoms.data["z"]) # Set the atoms in the input after translating them for the offset - input_multislice.spec_atoms = self.atoms.translate( + simulate.input.spec_atoms = self.atoms.translate( (offset, offset, 0) ).to_multem() - # Run the simulation - output_multislice = multem.simulate(system_conf, input_multislice) - - # Get the ideal image data - # Multem outputs data in column major format. In C++ and Python we - # generally deal with data in row major format so we must do a - # transpose here. - image = np.array(output_multislice.data[0].psi_coh).T + # Do the simulation + image = simulate.image() # Print some info psi_tot = np.abs(image) ** 2 diff --git a/src/parakeet/simulate/engine.py b/src/parakeet/simulate/engine.py new file mode 100644 index 00000000..75d37f91 --- /dev/null +++ b/src/parakeet/simulate/engine.py @@ -0,0 +1,513 @@ +# +# parakeet.simulate.engine.py +# +# Copyright (C) 2019 Diamond Light Source and Rosalind Franklin Institute +# +# Author: James Parkhurst +# +# This code is distributed under the GPLv3 license, a copy of +# which is included in the root directory of this package. +# + +import logging +import numpy as np +import warnings +from math import sqrt, pi, floor +from collections.abc import Iterable + + +# Try to input MULTEM +try: + import multem +except ImportError: + warnings.warn("Could not import MULTEM") + + +# Get the logger +logger = logging.getLogger(__name__) + + +def defocus_spread(Cc, dEE, dII, dVV): + """ + From equation 3.41 in Kirkland: Advanced Computing in Electron Microscopy + + The dE, dI, dV are the 1/e half widths or E, I and V respectively + + Args: + Cc (float): The chromatic abberation + dEE (float): dE/E, the fluctuation in the electron energy + dII (float): dI/I, the fluctuation in the lens current + dVV (float): dV/V, the fluctuation in the acceleration voltage + + Returns: + + """ + return Cc * sqrt((dEE) ** 2 + (2 * dII) ** 2 + (dVV) ** 2) + + +class SimulationEngine(object): + """ + A class to encapsulate the multem stuff + + """ + + def __init__( + self, + device, + gpu_id, + microscope, + slice_thickness, + margin, + simulation_type, + centre=None, + ): + """ + Initialise the simulation engine + + """ + + # Save the margin + self.margin = margin + + # Setup the system configuration + self.system_conf = self._create_system_configuration(device, gpu_id) + + # Setup the input multislice + if simulation_type in ["CBED"]: + self.input = self._create_input_multislice_diffraction( + microscope, slice_thickness, margin, simulation_type, centre + ) + else: + self.input = self._create_input_multislice( + microscope, slice_thickness, margin, simulation_type, centre + ) + + def _create_system_configuration(self, device, gpu_id): + """ + Create an appropriate system configuration + + Args: + device (str): The device to use + gpu_id (int): The gpu id + + Returns: + object: The system configuration + + """ + assert device in ["cpu", "gpu"] + + # Initialise the system configuration + system_conf = multem.SystemConfiguration() + + # Set the precision + system_conf.precision = "float" + + # Set the device + if device == "gpu": + if multem.is_gpu_available(): + system_conf.device = "device" + else: + system_conf.device = "host" + warnings.warn("GPU not present, reverting to CPU") + else: + system_conf.device = "host" + + # Set the GPU ID + if gpu_id is not None: + system_conf.gpu_device = gpu_id + + # Print some output + logger.info("Simulating using %s" % system_conf.device) + + # Return the system configuration + return system_conf + + def _create_input_multislice( + self, microscope, slice_thickness, margin, simulation_type, centre=None + ): + """ + Create the input multislice object + + Args: + microscope (object): The microscope object + slice_thickness (float): The slice thickness + margin (int): The pixel margin + + Returns: + object: The input multislice object + + """ + + # Initialise the input and system configuration + input_multislice = multem.Input() + + # Set simulation experiment + input_multislice.simulation_type = simulation_type + + # Electron-Specimen interaction model + input_multislice.interaction_model = "Multislice" + input_multislice.potential_type = "Lobato_0_12" + + # Potential slicing + # XXX If this is set to "Planes" then for the ribosome example I found that + # the simulation would not work well (e.g. The image may have nothing or a + # single point of intensity and nothing else). Best to keep this set to + # dz_Proj. + input_multislice.potential_slicing = "dz_Proj" + + # Electron-Phonon interaction model + input_multislice.pn_model = "Still_Atom" # "Frozen_Phonon" + # input_multislice.pn_model = "Frozen_Phonon" + input_multislice.pn_coh_contrib = 0 + input_multislice.pn_single_conf = False + input_multislice.pn_nconf = 50 + input_multislice.pn_dim = 110 + input_multislice.pn_seed = 300_183 + + # Set the slice thickness + input_multislice.spec_dz = slice_thickness + + # Specimen thickness + input_multislice.thick_type = "Whole_Spec" + + # x-y sampling + input_multislice.nx = microscope.detector.nx + margin * 2 + input_multislice.ny = microscope.detector.ny + margin * 2 + input_multislice.bwl = False + + # Microscope parameters + input_multislice.E_0 = microscope.beam.energy + input_multislice.theta = microscope.beam.theta + input_multislice.phi = microscope.beam.phi + + # Illumination model + input_multislice.illumination_model = "Partial_Coherent" + input_multislice.temporal_spatial_incoh = "Temporal_Spatial" + + # Condenser lens + # source spread function + ssf_sigma = multem.mrad_to_sigma( + input_multislice.E_0, microscope.beam.illumination_semiangle + ) + input_multislice.cond_lens_si_sigma = ssf_sigma + + # Objective lens + input_multislice.obj_lens_m = microscope.lens.m + input_multislice.obj_lens_c_10 = microscope.lens.c_10 + input_multislice.obj_lens_c_12 = microscope.lens.c_12 + input_multislice.obj_lens_phi_12 = microscope.lens.phi_12 + input_multislice.obj_lens_c_21 = microscope.lens.c_21 + input_multislice.obj_lens_phi_21 = microscope.lens.phi_21 + input_multislice.obj_lens_c_23 = microscope.lens.c_23 + input_multislice.obj_lens_phi_23 = microscope.lens.phi_23 + input_multislice.obj_lens_c_30 = microscope.lens.c_30 + input_multislice.obj_lens_c_32 = microscope.lens.c_32 + input_multislice.obj_lens_phi_32 = microscope.lens.phi_32 + input_multislice.obj_lens_c_34 = microscope.lens.c_34 + input_multislice.obj_lens_phi_34 = microscope.lens.phi_34 + input_multislice.obj_lens_c_41 = microscope.lens.c_41 + input_multislice.obj_lens_phi_41 = microscope.lens.phi_41 + input_multislice.obj_lens_c_43 = microscope.lens.c_43 + input_multislice.obj_lens_phi_43 = microscope.lens.phi_43 + input_multislice.obj_lens_c_45 = microscope.lens.c_45 + input_multislice.obj_lens_phi_45 = microscope.lens.phi_45 + input_multislice.obj_lens_c_50 = microscope.lens.c_50 + input_multislice.obj_lens_c_52 = microscope.lens.c_52 + input_multislice.obj_lens_phi_52 = microscope.lens.phi_52 + input_multislice.obj_lens_c_54 = microscope.lens.c_54 + input_multislice.obj_lens_phi_54 = microscope.lens.phi_54 + input_multislice.obj_lens_c_56 = microscope.lens.c_56 + input_multislice.obj_lens_phi_56 = microscope.lens.phi_56 + input_multislice.obj_lens_inner_aper_ang = microscope.lens.inner_aper_ang + input_multislice.obj_lens_outer_aper_ang = microscope.lens.outer_aper_ang + + # Do we have a phase plate + # if microscope.phase_plate: + # input_multislice.phase_shift = pi / 2.0 + + # defocus spread function + input_multislice.obj_lens_ti_sigma = multem.iehwgd_to_sigma( + defocus_spread( + microscope.lens.c_c * 1e-3 / 1e-10, # Convert from mm to A + microscope.beam.energy_spread, + microscope.lens.current_spread, + microscope.beam.acceleration_voltage_spread, + ) + ) + + # zero defocus reference + if centre is not None: + input_multislice.cond_lens_zero_defocus_type = "User_Define" + input_multislice.obj_lens_zero_defocus_type = "User_Define" + input_multislice.cond_lens_zero_defocus_plane = centre + input_multislice.obj_lens_zero_defocus_plane = centre + else: + input_multislice.cond_lens_zero_defocus_type = "Last" + input_multislice.obj_lens_zero_defocus_type = "Last" + + # Return the input multislice object + return input_multislice + + def _create_input_multislice_diffraction( + self, microscope, slice_thickness, margin, simulation_type, centre=None + ): + """ + Create the input multislice object + + Args: + microscope (object): The microscope object + slice_thickness (float): The slice thickness + margin (int): The pixel margin + + Returns: + object: The input multislice object + + """ + + # Initialise the input and system configuration + input_multislice = multem.Input() + + # Set simulation experiment + input_multislice.simulation_type = simulation_type + + # Electron-Specimen interaction model + input_multislice.interaction_model = "Multislice" + input_multislice.potential_type = "Lobato_0_12" + + # Potential slicing + # XXX If this is set to "Planes" then for the ribosome example I found that + # the simulation would not work well (e.g. The image may have nothing or a + # single point of intensity and nothing else). Best to keep this set to + # dz_Proj. + input_multislice.potential_slicing = "dz_Proj" + + # Electron-Phonon interaction model + input_multislice.pn_model = "Still_Atom" # "Frozen_Phonon" + # input_multislice.pn_model = "Frozen_Phonon" + input_multislice.pn_coh_contrib = 0 + input_multislice.pn_single_conf = False + input_multislice.pn_nconf = 50 + input_multislice.pn_dim = 110 + input_multislice.pn_seed = 300_183 + + # Set the slice thickness + input_multislice.spec_dz = slice_thickness + + # Specimen thickness + input_multislice.thick_type = "Whole_Spec" + + # x-y sampling + input_multislice.nx = microscope.detector.nx + margin * 2 + input_multislice.ny = microscope.detector.ny + margin * 2 + input_multislice.bwl = False + + # Microscope parameters + input_multislice.E_0 = microscope.beam.energy + input_multislice.theta = microscope.beam.theta + input_multislice.phi = microscope.beam.phi + + # Illumination model + input_multislice.illumination_model = "Partial_Coherent" + input_multislice.temporal_spatial_incoh = "Temporal_Spatial" + + # Set the incident wave + # For some reason need this to work with CBED + input_multislice.iw_x = [0] # input_multislice.spec_lx/2 + input_multislice.iw_y = [0] # input_multislice.spec_ly/2 + + # Condenser lens + # source spread (illumination semiangle) function + ssf_sigma = multem.mrad_to_sigma( + input_multislice.E_0, microscope.beam.illumination_semiangle + ) + input_multislice.cond_lens_si_sigma = ssf_sigma + + # Objective lens + input_multislice.cond_lens_m = microscope.lens.m + input_multislice.cond_lens_c_10 = microscope.lens.c_10 + input_multislice.cond_lens_c_12 = microscope.lens.c_12 + input_multislice.cond_lens_phi_12 = microscope.lens.phi_12 + input_multislice.cond_lens_c_21 = microscope.lens.c_21 + input_multislice.cond_lens_phi_21 = microscope.lens.phi_21 + input_multislice.cond_lens_c_23 = microscope.lens.c_23 + input_multislice.cond_lens_phi_23 = microscope.lens.phi_23 + input_multislice.cond_lens_c_30 = microscope.lens.c_30 + input_multislice.cond_lens_c_32 = microscope.lens.c_32 + input_multislice.cond_lens_phi_32 = microscope.lens.phi_32 + input_multislice.cond_lens_c_34 = microscope.lens.c_34 + input_multislice.cond_lens_phi_34 = microscope.lens.phi_34 + input_multislice.cond_lens_c_41 = microscope.lens.c_41 + input_multislice.cond_lens_phi_41 = microscope.lens.phi_41 + input_multislice.cond_lens_c_43 = microscope.lens.c_43 + input_multislice.cond_lens_phi_43 = microscope.lens.phi_43 + input_multislice.cond_lens_c_45 = microscope.lens.c_45 + input_multislice.cond_lens_phi_45 = microscope.lens.phi_45 + input_multislice.cond_lens_c_50 = microscope.lens.c_50 + input_multislice.cond_lens_c_52 = microscope.lens.c_52 + input_multislice.cond_lens_phi_52 = microscope.lens.phi_52 + input_multislice.cond_lens_c_54 = microscope.lens.c_54 + input_multislice.cond_lens_phi_54 = microscope.lens.phi_54 + input_multislice.cond_lens_c_56 = microscope.lens.c_56 + input_multislice.cond_lens_phi_56 = microscope.lens.phi_56 + input_multislice.cond_lens_inner_aper_ang = microscope.lens.inner_aper_ang + input_multislice.cond_lens_outer_aper_ang = microscope.lens.outer_aper_ang + + # Do we have a phase plate + if microscope.phase_plate: + input_multislice.phase_shift = pi / 2.0 + + # defocus spread function + input_multislice.obj_lens_ti_sigma = multem.iehwgd_to_sigma( + defocus_spread( + microscope.lens.c_c * 1e-3 / 1e-10, # Convert from mm to A + microscope.beam.energy_spread, + microscope.lens.current_spread, + microscope.beam.acceleration_voltage_spread, + ) + ) + + # zero defocus reference + if centre is not None: + input_multislice.cond_lens_zero_defocus_type = "User_Define" + input_multislice.obj_lens_zero_defocus_type = "User_Define" + input_multislice.cond_lens_zero_defocus_plane = centre + input_multislice.obj_lens_zero_defocus_plane = centre + else: + input_multislice.cond_lens_zero_defocus_type = "Last" + input_multislice.obj_lens_zero_defocus_type = "Last" + + # Return the input multislice object + return input_multislice + + def ctf(self): + """ + Simulate the CTF + + """ + return np.array(multem.compute_ctf(self.system_conf, self.input)).T + + def potential(self, out, volume_z0): + """ + Simulate the potential + + """ + + margin = self.margin + slice_thickness = self.input.spec_dz + + def callback(z0, z1, V): + V = np.array(V) + zc = (z0 + z1) / 2.0 + index = int(floor((zc - volume_z0) / slice_thickness)) + print( + "Calculating potential for slice: %.2f -> %.2f (index: %d)" + % (z0, z1, index) + ) + if index < out.data.shape[0]: + out.data[index, :, :] = V[margin:-margin, margin:-margin].T + + # Run the simulation + multem.compute_projected_potential(self.system_conf, self.input, callback) + + def image(self, masker=None): + """ + Simulate the image + + """ + # Run the simulation + if masker is not None: + output_multislice = multem.simulate(self.system_conf, self.input, masker) + else: + output_multislice = multem.simulate(self.system_conf, self.input) + + # Get the ideal image data + # Multem outputs data in column major format. In C++ and Python we + # generally deal with data in row major format so we must do a + # transpose here. + return np.array(output_multislice.data[0].psi_coh).T + + def masker( + self, + index, + pixel_size, + origin, + offset, + orientation, + shift, + sample, + scan, + simulation, + ): + """ + Get the masker object for the ice specification + + """ + + # Create the masker + masker = multem.Masker(self.input.nx, self.input.ny, pixel_size) + + # Set the ice parameters + ice_parameters = multem.IceParameters() + ice_parameters.m1 = simulation["ice_parameters"]["m1"] + ice_parameters.m2 = simulation["ice_parameters"]["m2"] + ice_parameters.s1 = simulation["ice_parameters"]["s1"] + ice_parameters.s2 = simulation["ice_parameters"]["s2"] + ice_parameters.a1 = simulation["ice_parameters"]["a1"] + ice_parameters.a2 = simulation["ice_parameters"]["a2"] + ice_parameters.density = simulation["ice_parameters"]["density"] + masker.set_ice_parameters(ice_parameters) + + # Get the sample centre + shape = sample.shape + centre = np.array(sample.centre) + detector_origin = np.array([origin[0], origin[1], 0]) + centre = centre + offset - detector_origin - shift + + # Set the shape + if shape["type"] == "cube": + length = shape["cube"]["length"] + masker.set_cuboid( + ( + centre[0] - length / 2, + centre[1] - length / 2, + centre[2] - length / 2, + ), + (length, length, length), + ) + elif shape["type"] == "cuboid": + length_x = shape["cuboid"]["length_x"] + length_y = shape["cuboid"]["length_y"] + length_z = shape["cuboid"]["length_z"] + masker.set_cuboid( + ( + centre[0] - length_x / 2, + centre[1] - length_y / 2, + centre[2] - length_z / 2, + ), + (length_x, length_y, length_z), + ) + elif shape["type"] == "cylinder": + radius = shape["cylinder"]["radius"] + if not isinstance(radius, Iterable): + radius = [radius] + length = shape["cylinder"]["length"] + offset_x = shape["cylinder"].get("offset_x", [0] * len(radius)) + offset_z = shape["cylinder"].get("offset_z", [0] * len(radius)) + axis = shape["cylinder"].get("axis", (0, 1, 0)) + masker.set_cylinder( + (centre[0], centre[1] - length / 2, centre[2]), + axis, + length, + list(radius), + list(offset_x), + list(offset_z), + ) + + # Rotate unless we have a single particle type simulation + if scan.is_uniform_angular_scan: + masker.set_rotation(centre, (0, 0, 0)) + else: + masker.set_rotation(centre, orientation) + + # Get the masker + return masker diff --git a/src/parakeet/simulate/phase_plate.py b/src/parakeet/simulate/phase_plate.py index 48571b92..900a01b5 100644 --- a/src/parakeet/simulate/phase_plate.py +++ b/src/parakeet/simulate/phase_plate.py @@ -1,6 +1,16 @@ import numpy as np +def compute_phase_shift_for_freq(k, phase_shift=np.pi / 2, radius=0.005): + """ + Compute the phase shift from a phase plate + + """ + # Multiply the wave with the phase shift from the phase plate which is + # approximated by a phase shift applied only on the near field terms + return np.exp(1j * phase_shift * (1 - np.exp(-(k**2) / (2 * radius**2)))) + + def compute_phase_shift(shape, pixel_size, phase_shift=np.pi / 2, radius=0.005): """ Compute the phase shift from a phase plate @@ -15,6 +25,4 @@ def compute_phase_shift(shape, pixel_size, phase_shift=np.pi / 2, radius=0.005): # Multiply the wave with the phase shift from the phase plate which is # approximated by a phase shift applied only on the near field terms - return np.fft.ifftshift( - np.exp(1j * phase_shift * (1 - np.exp(-(k**2) / (2 * radius**2)))) - ) + return np.fft.ifftshift(compute_phase_shift_for_freq(k, phase_shift, radius)) diff --git a/src/parakeet/simulate/simulation.py b/src/parakeet/simulate/simulation.py index a0c4f265..8bd3550c 100644 --- a/src/parakeet/simulate/simulation.py +++ b/src/parakeet/simulate/simulation.py @@ -11,355 +11,38 @@ import logging import numpy as np -import warnings import parakeet.config import parakeet.dqe import parakeet.freeze import parakeet.futures import parakeet.inelastic import parakeet.sample -from math import sqrt, pi - -# Try to input MULTEM -try: - import multem -except ImportError: - warnings.warn("Could not import MULTEM") +from typing import Tuple # Get the logger logger = logging.getLogger(__name__) -def defocus_spread(Cc, dEE, dII, dVV): - """ - From equation 3.41 in Kirkland: Advanced Computing in Electron Microscopy - - The dE, dI, dV are the 1/e half widths or E, I and V respectively - - Args: - Cc (float): The chromatic abberation - dEE (float): dE/E, the fluctuation in the electron energy - dII (float): dI/I, the fluctuation in the lens current - dVV (float): dV/V, the fluctuation in the acceleration voltage - - Returns: - - """ - return Cc * sqrt((dEE) ** 2 + (2 * dII) ** 2 + (dVV) ** 2) - - -def create_system_configuration(device, gpu_id=0): - """ - Create an appropriate system configuration - - Args: - device (str): The device to use - - Returns: - object: The system configuration - - """ - assert device in ["cpu", "gpu"] - - # Initialise the system configuration - system_conf = multem.SystemConfiguration() - - # Set the precision - system_conf.precision = "float" - - # Set the device - if device == "gpu": - if multem.is_gpu_available(): - system_conf.device = "device" - else: - system_conf.device = "host" - warnings.warn("GPU not present, reverting to CPU") - else: - system_conf.device = "host" - - # Set the gpu_device - if gpu_id is not None: - system_conf.gpu_device = gpu_id - - # Print some output - logger.info("Simulating using %s" % system_conf.device) - - # Return the system configuration - return system_conf - - -def create_input_multislice( - microscope, slice_thickness, margin, simulation_type, centre=None -): - """ - Create the input multislice object - - Args: - microscope (object): The microscope object - slice_thickness (float): The slice thickness - margin (int): The pixel margin - - Returns: - object: The input multislice object - - """ - - # Initialise the input and system configuration - input_multislice = multem.Input() - - # Set simulation experiment - input_multislice.simulation_type = simulation_type - - # Electron-Specimen interaction model - input_multislice.interaction_model = "Multislice" - input_multislice.potential_type = "Lobato_0_12" - - # Potential slicing - # XXX If this is set to "Planes" then for the ribosome example I found that - # the simulation would not work well (e.g. The image may have nothing or a - # single point of intensity and nothing else). Best to keep this set to - # dz_Proj. - input_multislice.potential_slicing = "dz_Proj" - - # Electron-Phonon interaction model - input_multislice.pn_model = "Still_Atom" # "Frozen_Phonon" - # input_multislice.pn_model = "Frozen_Phonon" - input_multislice.pn_coh_contrib = 0 - input_multislice.pn_single_conf = False - input_multislice.pn_nconf = 50 - input_multislice.pn_dim = 110 - input_multislice.pn_seed = 300_183 - - # Set the slice thickness - input_multislice.spec_dz = slice_thickness - - # Specimen thickness - input_multislice.thick_type = "Whole_Spec" - - # x-y sampling - input_multislice.nx = microscope.detector.nx + margin * 2 - input_multislice.ny = microscope.detector.ny + margin * 2 - input_multislice.bwl = False - - # Microscope parameters - input_multislice.E_0 = microscope.beam.energy - input_multislice.theta = microscope.beam.theta - input_multislice.phi = microscope.beam.phi - - # Illumination model - input_multislice.illumination_model = "Partial_Coherent" - input_multislice.temporal_spatial_incoh = "Temporal_Spatial" - - # Condenser lens - # source spread (illumination semiangle) function - ssf_sigma = multem.mrad_to_sigma( - input_multislice.E_0, microscope.beam.illumination_semiangle - ) - input_multislice.cond_lens_si_sigma = ssf_sigma - - # Objective lens - input_multislice.obj_lens_m = microscope.lens.m - input_multislice.obj_lens_c_10 = microscope.lens.c_10 - input_multislice.obj_lens_c_12 = microscope.lens.c_12 - input_multislice.obj_lens_phi_12 = microscope.lens.phi_12 - input_multislice.obj_lens_c_21 = microscope.lens.c_21 - input_multislice.obj_lens_phi_21 = microscope.lens.phi_21 - input_multislice.obj_lens_c_23 = microscope.lens.c_23 - input_multislice.obj_lens_phi_23 = microscope.lens.phi_23 - input_multislice.obj_lens_c_30 = microscope.lens.c_30 - input_multislice.obj_lens_c_32 = microscope.lens.c_32 - input_multislice.obj_lens_phi_32 = microscope.lens.phi_32 - input_multislice.obj_lens_c_34 = microscope.lens.c_34 - input_multislice.obj_lens_phi_34 = microscope.lens.phi_34 - input_multislice.obj_lens_c_41 = microscope.lens.c_41 - input_multislice.obj_lens_phi_41 = microscope.lens.phi_41 - input_multislice.obj_lens_c_43 = microscope.lens.c_43 - input_multislice.obj_lens_phi_43 = microscope.lens.phi_43 - input_multislice.obj_lens_c_45 = microscope.lens.c_45 - input_multislice.obj_lens_phi_45 = microscope.lens.phi_45 - input_multislice.obj_lens_c_50 = microscope.lens.c_50 - input_multislice.obj_lens_c_52 = microscope.lens.c_52 - input_multislice.obj_lens_phi_52 = microscope.lens.phi_52 - input_multislice.obj_lens_c_54 = microscope.lens.c_54 - input_multislice.obj_lens_phi_54 = microscope.lens.phi_54 - input_multislice.obj_lens_c_56 = microscope.lens.c_56 - input_multislice.obj_lens_phi_56 = microscope.lens.phi_56 - input_multislice.obj_lens_inner_aper_ang = microscope.lens.inner_aper_ang - input_multislice.obj_lens_outer_aper_ang = microscope.lens.outer_aper_ang - - # Do we have a phase plate - # if microscope.phase_plate: - # input_multislice.phase_shift = pi / 2.0 - - # defocus spread function - input_multislice.obj_lens_ti_sigma = multem.iehwgd_to_sigma( - defocus_spread( - microscope.lens.c_c * 1e-3 / 1e-10, # Convert from mm to A - microscope.beam.energy_spread, - microscope.lens.current_spread, - microscope.beam.acceleration_voltage_spread, - ) - ) - - # zero defocus reference - if centre is not None: - input_multislice.cond_lens_zero_defocus_type = "User_Define" - input_multislice.obj_lens_zero_defocus_type = "User_Define" - input_multislice.cond_lens_zero_defocus_plane = centre - input_multislice.obj_lens_zero_defocus_plane = centre - else: - input_multislice.cond_lens_zero_defocus_type = "Last" - input_multislice.obj_lens_zero_defocus_type = "Last" - - # Return the input multislice object - return input_multislice - - -def create_input_multislice_diffraction( - microscope, slice_thickness, margin, simulation_type, centre=None -): - """ - Create the input multislice object - - Args: - microscope (object): The microscope object - slice_thickness (float): The slice thickness - margin (int): The pixel margin - - Returns: - object: The input multislice object - - """ - - # Initialise the input and system configuration - input_multislice = multem.Input() - - # Set simulation experiment - input_multislice.simulation_type = simulation_type - - # Electron-Specimen interaction model - input_multislice.interaction_model = "Multislice" - input_multislice.potential_type = "Lobato_0_12" - - # Potential slicing - # XXX If this is set to "Planes" then for the ribosome example I found that - # the simulation would not work well (e.g. The image may have nothing or a - # single point of intensity and nothing else). Best to keep this set to - # dz_Proj. - input_multislice.potential_slicing = "dz_Proj" - - # Electron-Phonon interaction model - input_multislice.pn_model = "Still_Atom" # "Frozen_Phonon" - # input_multislice.pn_model = "Frozen_Phonon" - input_multislice.pn_coh_contrib = 0 - input_multislice.pn_single_conf = False - input_multislice.pn_nconf = 50 - input_multislice.pn_dim = 110 - input_multislice.pn_seed = 300_183 - - # Set the slice thickness - input_multislice.spec_dz = slice_thickness - - # Specimen thickness - input_multislice.thick_type = "Whole_Spec" - - # x-y sampling - input_multislice.nx = microscope.detector.nx + margin * 2 - input_multislice.ny = microscope.detector.ny + margin * 2 - input_multislice.bwl = False - - # Microscope parameters - input_multislice.E_0 = microscope.beam.energy - input_multislice.theta = microscope.beam.theta - input_multislice.phi = microscope.beam.phi - - # Illumination model - input_multislice.illumination_model = "Partial_Coherent" - input_multislice.temporal_spatial_incoh = "Temporal_Spatial" - - # Set the incident wave - # For some reason need this to work with CBED - input_multislice.iw_x = [0] # input_multislice.spec_lx/2 - input_multislice.iw_y = [0] # input_multislice.spec_ly/2 - - # Condenser lens - # source spread (illumination semiangle) function - ssf_sigma = multem.mrad_to_sigma( - input_multislice.E_0, microscope.beam.illumination_semiangle - ) - input_multislice.cond_lens_si_sigma = ssf_sigma - - # Objective lens - input_multislice.cond_lens_m = microscope.lens.m - input_multislice.cond_lens_c_10 = microscope.lens.c_10 - input_multislice.cond_lens_c_12 = microscope.lens.c_12 - input_multislice.cond_lens_phi_12 = microscope.lens.phi_12 - input_multislice.cond_lens_c_21 = microscope.lens.c_21 - input_multislice.cond_lens_phi_21 = microscope.lens.phi_21 - input_multislice.cond_lens_c_23 = microscope.lens.c_23 - input_multislice.cond_lens_phi_23 = microscope.lens.phi_23 - input_multislice.cond_lens_c_30 = microscope.lens.c_30 - input_multislice.cond_lens_c_32 = microscope.lens.c_32 - input_multislice.cond_lens_phi_32 = microscope.lens.phi_32 - input_multislice.cond_lens_c_34 = microscope.lens.c_34 - input_multislice.cond_lens_phi_34 = microscope.lens.phi_34 - input_multislice.cond_lens_c_41 = microscope.lens.c_41 - input_multislice.cond_lens_phi_41 = microscope.lens.phi_41 - input_multislice.cond_lens_c_43 = microscope.lens.c_43 - input_multislice.cond_lens_phi_43 = microscope.lens.phi_43 - input_multislice.cond_lens_c_45 = microscope.lens.c_45 - input_multislice.cond_lens_phi_45 = microscope.lens.phi_45 - input_multislice.cond_lens_c_50 = microscope.lens.c_50 - input_multislice.cond_lens_c_52 = microscope.lens.c_52 - input_multislice.cond_lens_phi_52 = microscope.lens.phi_52 - input_multislice.cond_lens_c_54 = microscope.lens.c_54 - input_multislice.cond_lens_phi_54 = microscope.lens.phi_54 - input_multislice.cond_lens_c_56 = microscope.lens.c_56 - input_multislice.cond_lens_phi_56 = microscope.lens.phi_56 - input_multislice.cond_lens_inner_aper_ang = microscope.lens.inner_aper_ang - input_multislice.cond_lens_outer_aper_ang = microscope.lens.outer_aper_ang - - # Do we have a phase plate - if microscope.phase_plate: - input_multislice.phase_shift = pi / 2.0 - - # defocus spread function - input_multislice.obj_lens_ti_sigma = multem.iehwgd_to_sigma( - defocus_spread( - microscope.lens.c_c * 1e-3 / 1e-10, # Convert from mm to A - microscope.beam.energy_spread, - microscope.lens.current_spread, - microscope.beam.acceleration_voltage_spread, - ) - ) - - # zero defocus reference - if centre is not None: - input_multislice.cond_lens_zero_defocus_type = "User_Define" - input_multislice.obj_lens_zero_defocus_type = "User_Define" - input_multislice.cond_lens_zero_defocus_plane = centre - input_multislice.obj_lens_zero_defocus_plane = centre - else: - input_multislice.cond_lens_zero_defocus_type = "Last" - input_multislice.obj_lens_zero_defocus_type = "Last" - - # Return the input multislice object - return input_multislice - - class Simulation(object): """ An object to wrap the simulation """ - def __init__(self, image_size, pixel_size, scan=None, nproc=1, simulate_image=None): + def __init__( + self, + image_size: Tuple[int, int], + pixel_size: float, + scan=None, + nproc=1, + simulate_image=None, + ): """ Initialise the simulation Args: - image_size (tuple): The image size + image_size: The image size scan (object): The scan object nproc: The number of processes simulate_image (func): The image simulation function @@ -372,10 +55,10 @@ def __init__(self, image_size, pixel_size, scan=None, nproc=1, simulate_image=No self.simulate_image = simulate_image @property - def shape(self): + def shape(self) -> Tuple[int, int, int]: """ Return - tuple: The simulation data shape + The simulation data shape """ nx = self.image_size[0] @@ -385,7 +68,12 @@ def shape(self): nz = len(self.scan) return (nz, ny, nx) - def angles(self): + def angles(self) -> list: + """ + Return: + The simulation angles + + """ if self.scan is None: return [(0, 0, 0)] return list( diff --git a/src/parakeet/util/calibrate_ice_model.py b/src/parakeet/util/calibrate_ice_model.py index 7d989e5b..af8f8420 100644 --- a/src/parakeet/util/calibrate_ice_model.py +++ b/src/parakeet/util/calibrate_ice_model.py @@ -1063,8 +1063,8 @@ def plot_all_mean_and_power(pixel_size, stats_list, power_list): ) = map(np.array, zip(*stats_list)) width = 0.0393701 * 190 - height = (4 / 8) * width - fig, ax = pylab.subplots(figsize=(width, height), ncols=2, constrained_layout=True) + height = (3 / 8) * width + fig, ax = pylab.subplots(figsize=(width, height), ncols=3, constrained_layout=True) l1 = ax[0].plot(pixel_size, p_real, label="Physical (real)") l2 = ax[0].plot(pixel_size, p_imag, label="Physical (imag)") l3 = ax[0].plot(pixel_size, r_real, label="Random (real)") @@ -1097,18 +1097,30 @@ def plot_all_mean_and_power(pixel_size, stats_list, power_list): color=l4[0].get_color(), alpha=0.3, ) - ax[0].legend(loc="lower right") - ax[0].set_xlabel("Pixel size (A)\n(a)") - ax[0].set_ylabel("Exit wave mean and standard deviation") + ax[0].legend(loc="lower right", fontsize=6) + ax[0].set_xlabel("Pixel size (Å)\n(a)") + ax[0].set_ylabel("Exit wave mean") + + ax[1].scatter(pixel_size, r_real - p_real, label="Real") + ax[1].scatter(pixel_size, r_imag - p_imag, label="Imag") + ax[1].set_ylim((-0.01, 0.01)) + ax[1].set_yticks([-0.01, 0, 0.01]) + ax[1].legend(fontsize=8) + ax[1].set_xlabel("Pixel size (Å)\n(b)") + ax[1].set_ylabel("Difference in exit wave mean\n(physical - random)") + ax[0].tick_params(axis="both", which="major", labelsize=8) + ax[1].tick_params(axis="both", which="major", labelsize=8) + ax[2].tick_params(axis="both", which="major", labelsize=8) + ax[2].set(yticklabels=[]) cycle = pylab.rcParams["axes.prop_cycle"].by_key()["color"] for ps, power in zip(pixel_size, power_list): - p1 = ax[1].plot(power[0], power[1], color=cycle[0], alpha=0.5) - p2 = ax[1].plot(power[2], power[3], color=cycle[1], alpha=0.5) - ax[1].set_xlabel("Spatial frequency (1/Å)\n(b)") - ax[1].set_ylabel("Power spectrum") - ax[1].legend(handles=[p1[0], p2[0]], labels=["Physical", "Random"]) - ax[1].set_xlim(0, 1.0) + p1 = ax[2].plot(power[0], power[1], color=cycle[0], alpha=0.5) + p2 = ax[2].plot(power[2], power[3], color=cycle[1], alpha=0.5) + ax[2].set_xlabel("Spatial frequency (1/Å)\n(c)") + ax[2].set_ylabel("Power spectrum") + ax[2].legend(handles=[p1[0], p2[0]], labels=["Physical", "Random"], fontsize=8) + ax[2].set_xlim(0, 1.0) # pylab.show() fig.savefig("mean_and_power.png", dpi=300, bbox_inches="tight") pylab.close("all") @@ -1183,7 +1195,7 @@ def validate(): plot_all_edge(pixel_size, edge_list) -if __name__ == "__main__": +def main(): # Create the argument parser parser = argparse.ArgumentParser(description="Do the ice model configuration") @@ -1220,3 +1232,7 @@ def validate(): # Do the validation if args.validate: validate() + + +if __name__ == "__main__": + main() diff --git a/tests/test_inelastic.py b/tests/test_inelastic.py index adb4851f..24c1bf4a 100644 --- a/tests/test_inelastic.py +++ b/tests/test_inelastic.py @@ -1,5 +1,6 @@ import pytest -from math import exp +import numpy as np +from math import exp, sqrt from parakeet import inelastic @@ -103,3 +104,17 @@ def test_most_probable_loss(): peak, sigma = inelastic.most_probable_loss(300, cube, 0) assert peak == pytest.approx(17.92806966151457) assert sigma == pytest.approx(5.300095984425282) + + +@pytest.mark.parametrize("thickness", [100, 1000, 4000]) +def test_get_energy_bins(thickness): + bin_energy, bin_spread, bin_weight = inelastic.get_energy_bins( + energy=300000, thickness=thickness, energy_spread=0.798 + ) + + assert np.min(bin_weight) >= 0 + assert np.max(bin_weight) <= 1 + assert np.isclose(np.sum(bin_weight), 1) + assert np.argmax(bin_weight) == 2 + assert np.max(bin_spread) <= sqrt(5**2 / 12) * sqrt(2) + 0.05 + assert np.min(bin_spread) >= 0