-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #214 from slaclab/image_refactoring
Image refactoring
- Loading branch information
Showing
12 changed files
with
182 additions
and
195 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -191,3 +191,6 @@ dkms.conf | |
|
||
# Added by vscode | ||
.vscode | ||
|
||
# development folder | ||
dev |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
|
||
from lcls_tools.common.image.fit import ImageProjectionFitResult | ||
|
||
|
||
def plot_image_projection_fit(result: ImageProjectionFitResult): | ||
""" | ||
plot image and projection data for validation | ||
""" | ||
fig, ax = plt.subplots(3, 1) | ||
fig.set_size_inches(4, 9) | ||
|
||
image = result.processed_image | ||
ax[0].imshow(image) | ||
|
||
projections = { | ||
"x": np.array(np.sum(image, axis=0)), | ||
"y": np.array(np.sum(image, axis=1)) | ||
} | ||
|
||
ax[0].plot(*result.centroid, "+r") | ||
|
||
# plot data and model fit | ||
for i, name in enumerate(["x", "y"]): | ||
fit_params = getattr(result, f"{name}_projection_fit_parameters") | ||
ax[i + 1].text(0.01, 0.99, | ||
"\n".join([ | ||
f"{name}: {int(val)}" for name, val in | ||
fit_params.items() | ||
]), | ||
transform=ax[i + 1].transAxes, | ||
ha='left', va='top', fontsize=10) | ||
x = np.arange(len(projections[name])) | ||
|
||
ax[i + 1].plot(projections[name], label="data") | ||
fit_param_numpy = np.array([fit_params[name] for name in | ||
result.projection_fit_method.parameters.parameters]) | ||
ax[i + 1].plot(result.projection_fit_method._forward(x, fit_param_numpy), | ||
label="model fit") | ||
|
||
return fig, ax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Optional, List | ||
|
||
import numpy as np | ||
from numpy import ndarray | ||
from pydantic import BaseModel, ConfigDict, PositiveFloat, Field | ||
|
||
from lcls_tools.common.data.fit.method_base import MethodBase | ||
from lcls_tools.common.data.fit.methods import GaussianModel | ||
from lcls_tools.common.data.fit.projection import ProjectionFit | ||
from lcls_tools.common.image.processing import ImageProcessor | ||
|
||
|
||
class ImageFitResult(BaseModel): | ||
centroid: List[float] = Field(min_length=2, max_length=2) | ||
rms_size: List[float] = Field(min_length=2, max_length=2) | ||
total_intensity: PositiveFloat | ||
processed_image: ndarray | ||
model_config = ConfigDict(arbitrary_types_allowed=True) | ||
|
||
|
||
class ImageProjectionFitResult(ImageFitResult): | ||
projection_fit_method: MethodBase | ||
x_projection_fit_parameters: dict[str, float] | ||
y_projection_fit_parameters: dict[str, float] | ||
|
||
|
||
class ImageFit(BaseModel, ABC): | ||
""" | ||
Abstract class for determining beam properties from an image | ||
""" | ||
image_processor: Optional[ImageProcessor] = ImageProcessor() | ||
model_config = ConfigDict(arbitrary_types_allowed=True) | ||
|
||
def fit_image(self, image: ndarray) -> ImageFitResult: | ||
""" | ||
Public method to determine beam properties from an image, including initial | ||
image processing, internal image fitting method, and image validation. | ||
""" | ||
processed_image = self.image_processor.auto_process(image) | ||
fit_result = self._fit_image(processed_image) | ||
return fit_result | ||
|
||
@abstractmethod | ||
def _fit_image(self, image: ndarray) -> ImageFitResult: | ||
""" | ||
Private image fitting method to be overwritten by subclasses. Expected to | ||
return a ImageFitResult dataclass. | ||
""" | ||
... | ||
|
||
|
||
class ImageProjectionFit(ImageFit): | ||
""" | ||
Image fitting class that gets the beam size and location by independently fitting | ||
the x/y projections. The default configuration uses a Gaussian fitting of the | ||
profile with prior distributions placed on the model parameters. | ||
""" | ||
projection_fit_method: Optional[MethodBase] = GaussianModel(use_priors=True) | ||
model_config = ConfigDict(arbitrary_types_allowed=True) | ||
|
||
def _fit_image(self, image: ndarray) -> ImageProjectionFitResult: | ||
x_projection = np.array(np.sum(image, axis=0)) | ||
y_projection = np.array(np.sum(image, axis=1)) | ||
|
||
proj_fit = ProjectionFit(model=self.projection_fit_method) | ||
|
||
x_parameters = proj_fit.fit_projection(x_projection) | ||
y_parameters = proj_fit.fit_projection(y_projection) | ||
|
||
result = ImageProjectionFitResult( | ||
centroid=[x_parameters["mean"], y_parameters["mean"]], | ||
rms_size=[x_parameters["sigma"], y_parameters["sigma"]], | ||
total_intensity=image.sum(), | ||
x_projection_fit_parameters=x_parameters, | ||
y_projection_fit_parameters=y_parameters, | ||
processed_image=image, | ||
projection_fit_method=self.projection_fit_method, | ||
) | ||
|
||
return result |
Oops, something went wrong.