Skip to content

Commit

Permalink
Merge pull request #214 from slaclab/image_refactoring
Browse files Browse the repository at this point in the history
Image refactoring
  • Loading branch information
roussel-ryan authored Dec 4, 2024
2 parents b14ffbf + 7f75da6 commit 91e650e
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 195 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,6 @@ dkms.conf

# Added by vscode
.vscode

# development folder
dev
125 changes: 0 additions & 125 deletions lcls_tools/common/data/fit/gaussian_fit.py

This file was deleted.

11 changes: 7 additions & 4 deletions lcls_tools/common/data/fit/method_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np

from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -49,7 +51,7 @@ def priors(self, priors: dict[str, float]):
# TODO: define properties


class MethodBase(ABC):
class MethodBase(ABC, BaseModel):
"""
Base abstract class for all fit methods, which serves as the bare minimum
skeleton code needed. Should be used only as a parent class to all method
Expand All @@ -63,7 +65,9 @@ class MethodBase(ABC):
and upper bound on for acceptable values of each parameter)
"""

parameters: ModelParameters = None
parameters: ModelParameters
use_priors: Optional[bool] = False
fitted_params_dict: Optional[dict] = None

@abstractmethod
def find_init_values(self) -> list:
Expand Down Expand Up @@ -121,10 +125,9 @@ def loss(
method_parameter_list: np.ndarray,
x: np.ndarray,
y: np.ndarray,
use_priors: bool = False,
):
loss_temp = -self._log_likelihood(x, y, method_parameter_list)
if use_priors:
if self.use_priors:
loss_temp = loss_temp - self._log_prior(method_parameter_list)
return loss_temp

Expand Down
4 changes: 2 additions & 2 deletions lcls_tools/common/data/fit/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class GaussianModel(MethodBase):
density functions to match that data.
"""

parameters = gaussian_parameters
parameters: ModelParameters = gaussian_parameters

def find_init_values(self) -> dict:
"""Fit data without optimization, return values."""
Expand All @@ -50,7 +50,7 @@ def find_init_values(self) -> dict:
self.parameters.initial_values = init_values
return init_values

def find_priors(self) -> dict:
def find_priors(self, **kwargs) -> dict:
"""
Do initial guesses based on data and make distribution from that guess.
"""
Expand Down
15 changes: 9 additions & 6 deletions lcls_tools/common/data/fit/projection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Optional

import numpy as np
import scipy.optimize
import scipy.signal
from pydantic import BaseModel, ConfigDict

from lcls_tools.common.data.fit.method_base import MethodBase
from lcls_tools.common.data.fit.methods import GaussianModel

Expand All @@ -27,8 +30,7 @@ class ProjectionFit(BaseModel):

# TODO: come up with better name
model_config = ConfigDict(arbitrary_types_allowed=True)
model: MethodBase = GaussianModel()
use_priors: bool = False
model: Optional[MethodBase] = GaussianModel()

def normalize(self, data: np.ndarray) -> np.ndarray:
"""
Expand All @@ -40,8 +42,8 @@ def normalize(self, data: np.ndarray) -> np.ndarray:
return normalized_data

def unnormalize_model_params(
self, method_params_dict: dict, projection_data: np.ndarray
) -> np.ndarray:
self, method_params_dict: dict, projection_data: np.ndarray
) -> dict:
"""
Takes fitted and normalized params and returns them
to unnormalized values i.e the true fitted values of the distribution
Expand Down Expand Up @@ -78,15 +80,15 @@ def fit_model(self) -> scipy.optimize._optimize.OptimizeResult:
res = scipy.optimize.minimize(
self.model.loss,
init_values,
args=(x, y, self.use_priors),
args=(x, y),
bounds=bounds,
method="Powell",
)
return res

def fit_projection(self, projection_data: np.ndarray) -> dict:
"""
type is dict[str, float]
Return type is dict[str, float]
Wrapper function that does all necessary steps to fit 1d array.
Returns a dictionary where the keys are the model params and their
values are the params fitted to the data
Expand All @@ -101,4 +103,5 @@ def fit_projection(self, projection_data: np.ndarray) -> dict:
fitted_params_dict[param] = (res.x)[i]
self.model.fitted_params_dict = fitted_params_dict.copy()
params_dict = self.unnormalize_model_params(fitted_params_dict, projection_data)

return params_dict
42 changes: 42 additions & 0 deletions lcls_tools/common/frontend/plotting/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
from matplotlib import pyplot as plt

from lcls_tools.common.image.fit import ImageProjectionFitResult


def plot_image_projection_fit(result: ImageProjectionFitResult):
"""
plot image and projection data for validation
"""
fig, ax = plt.subplots(3, 1)
fig.set_size_inches(4, 9)

image = result.processed_image
ax[0].imshow(image)

projections = {
"x": np.array(np.sum(image, axis=0)),
"y": np.array(np.sum(image, axis=1))
}

ax[0].plot(*result.centroid, "+r")

# plot data and model fit
for i, name in enumerate(["x", "y"]):
fit_params = getattr(result, f"{name}_projection_fit_parameters")
ax[i + 1].text(0.01, 0.99,
"\n".join([
f"{name}: {int(val)}" for name, val in
fit_params.items()
]),
transform=ax[i + 1].transAxes,
ha='left', va='top', fontsize=10)
x = np.arange(len(projections[name]))

ax[i + 1].plot(projections[name], label="data")
fit_param_numpy = np.array([fit_params[name] for name in
result.projection_fit_method.parameters.parameters])
ax[i + 1].plot(result.projection_fit_method._forward(x, fit_param_numpy),
label="model fit")

return fig, ax
82 changes: 82 additions & 0 deletions lcls_tools/common/image/fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from abc import ABC, abstractmethod
from typing import Optional, List

import numpy as np
from numpy import ndarray
from pydantic import BaseModel, ConfigDict, PositiveFloat, Field

from lcls_tools.common.data.fit.method_base import MethodBase
from lcls_tools.common.data.fit.methods import GaussianModel
from lcls_tools.common.data.fit.projection import ProjectionFit
from lcls_tools.common.image.processing import ImageProcessor


class ImageFitResult(BaseModel):
centroid: List[float] = Field(min_length=2, max_length=2)
rms_size: List[float] = Field(min_length=2, max_length=2)
total_intensity: PositiveFloat
processed_image: ndarray
model_config = ConfigDict(arbitrary_types_allowed=True)


class ImageProjectionFitResult(ImageFitResult):
projection_fit_method: MethodBase
x_projection_fit_parameters: dict[str, float]
y_projection_fit_parameters: dict[str, float]


class ImageFit(BaseModel, ABC):
"""
Abstract class for determining beam properties from an image
"""
image_processor: Optional[ImageProcessor] = ImageProcessor()
model_config = ConfigDict(arbitrary_types_allowed=True)

def fit_image(self, image: ndarray) -> ImageFitResult:
"""
Public method to determine beam properties from an image, including initial
image processing, internal image fitting method, and image validation.
"""
processed_image = self.image_processor.auto_process(image)
fit_result = self._fit_image(processed_image)
return fit_result

@abstractmethod
def _fit_image(self, image: ndarray) -> ImageFitResult:
"""
Private image fitting method to be overwritten by subclasses. Expected to
return a ImageFitResult dataclass.
"""
...


class ImageProjectionFit(ImageFit):
"""
Image fitting class that gets the beam size and location by independently fitting
the x/y projections. The default configuration uses a Gaussian fitting of the
profile with prior distributions placed on the model parameters.
"""
projection_fit_method: Optional[MethodBase] = GaussianModel(use_priors=True)
model_config = ConfigDict(arbitrary_types_allowed=True)

def _fit_image(self, image: ndarray) -> ImageProjectionFitResult:
x_projection = np.array(np.sum(image, axis=0))
y_projection = np.array(np.sum(image, axis=1))

proj_fit = ProjectionFit(model=self.projection_fit_method)

x_parameters = proj_fit.fit_projection(x_projection)
y_parameters = proj_fit.fit_projection(y_projection)

result = ImageProjectionFitResult(
centroid=[x_parameters["mean"], y_parameters["mean"]],
rms_size=[x_parameters["sigma"], y_parameters["sigma"]],
total_intensity=image.sum(),
x_projection_fit_parameters=x_parameters,
y_projection_fit_parameters=y_parameters,
processed_image=image,
projection_fit_method=self.projection_fit_method,
)

return result
Loading

0 comments on commit 91e650e

Please sign in to comment.