Skip to content

Commit

Permalink
refactor windows into spectrum_regions.py from config
Browse files Browse the repository at this point in the history
  • Loading branch information
David Wallace committed Mar 6, 2024
1 parent e03e136 commit df2ecb2
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 35 deletions.
1 change: 0 additions & 1 deletion src/raman_fitting/config/default_models/first_order.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

[first_order]

[first_order.models]
Expand Down
1 change: 0 additions & 1 deletion src/raman_fitting/config/default_models/second_order.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

[second_order]

[second_order.models]
Expand Down
10 changes: 10 additions & 0 deletions src/raman_fitting/config/default_models/spectrum_windows.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[spectrum]

[spectrum.windows]
full = {"min" = 200, "max" = 3600}
full_first_and_second = {"min" = 800, "max" = 3500}
low = {"min" = 150, "max" = 850, "extra_margin" = 10}
first_order = {"min" = 900, "max" = 2000}
mid = {"min" = 1850, "max" = 2150, "extra_margin" = 10}
normalization = {"min" = 1500, "max" = 1675, "extra_margin" = 10}
second_order = {"min" = 2150, "max" = 3380}
20 changes: 20 additions & 0 deletions src/raman_fitting/models/deconvolution/spectrum_regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from enum import StrEnum
from typing import Dict

from pydantic import BaseModel
from raman_fitting.config.default_models import load_config_from_toml_files


def get_default_regions_from_toml_files() -> Dict[str, Dict[str, float]]:
default_windows = load_config_from_toml_files().get("spectrum", {}).get("windows", {})
return default_windows


WindowNames = StrEnum('WindowNames', " ".join(get_default_regions_from_toml_files()), module=__name__)


class SpectrumWindowLimits(BaseModel):
name: WindowNames
min: int
max: int
extra_margin: int = 20
8 changes: 6 additions & 2 deletions src/raman_fitting/models/fit_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from lmfit.model import ModelResult

from raman_fitting.models.deconvolution.base_model import BaseLMFitModel
from raman_fitting.models.deconvolution.spectrum_regions import WindowNames

from raman_fitting.models.spectrum import SpectrumData

logger = logging.getLogger(__name__)
Expand All @@ -29,6 +31,7 @@ class SpectrumFitModel(BaseModel):

spectrum: SpectrumData
model: BaseLMFitModel
window: WindowNames
fit_kwargs: Dict = Field(default_factory=dict)
fit_result: ModelResult = Field(None, init_var=False)
elapsed_time: float = Field(0, init_var=False, repr=False)
Expand All @@ -43,7 +46,7 @@ def match_window_names(self):
)
return self

def run_fit(self) -> ModelResult:
def run_fit(self) -> None:
if "method" not in self.fit_kwargs:
self.fit_kwargs["method"] = "leastsq"
lmfit_model = self.model.lmfit_model
Expand All @@ -54,6 +57,7 @@ def run_fit(self) -> ModelResult:
self.elapsed_time = elapsed_seconds
self.fit_result = fit_result


def process_fit_results(self):
# TODO add parameter post processing steps
self.fit_result
Expand All @@ -77,7 +81,7 @@ def run_fit(


if __name__ == "__main__":
from raman_fitting.config.settings import settings
from raman_fitting.config.base_settings import settings

test_fixtures = list(settings.internal_paths.example_fixtures.glob("*txt"))
file = [i for i in test_fixtures if "_pos4" in i.stem][0]
Expand Down
36 changes: 5 additions & 31 deletions src/raman_fitting/models/splitter.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,11 @@
from enum import Enum
from typing import Dict, Any
import numpy as np

from pydantic import BaseModel, model_validator, Field
from .spectrum import SpectrumData
from .deconvolution.spectrum_regions import SpectrumWindowLimits, WindowNames, get_default_regions_from_toml_files


class WindowNames(str, Enum):
full = "full"
full_first_and_second = "full_first_and_second"
low = "low"
first_order = "first_order"
mid = "mid"
second_order = "second_order"
normalization = "normalization"


SPECTRUM_WINDOWS_LIMITS = {
WindowNames.full: {"min": 200, "max": 3600},
WindowNames.full_first_and_second: {"min": 800, "max": 3500},
WindowNames.low: {"min": 150, "max": 850, "extra_margin": 10},
WindowNames.first_order: {"min": 900, "max": 2000},
WindowNames.mid: {"min": 1850, "max": 2150, "extra_margin": 10},
WindowNames.second_order: {"min": 2150, "max": 3380},
WindowNames.normalization: {"min": 1500, "max": 1675, "extra_margin": 10},
}


class SpectrumWindowLimits(BaseModel):
name: WindowNames
min: int
max: int
extra_margin: int = 20


class SplittedSpectrum(BaseModel):
spectrum: SpectrumData
Expand Down Expand Up @@ -68,10 +41,11 @@ def get_window(self, window_name: WindowNames):
return self.spec_windows[spec_window_key]


def get_default_spectrum_window_limits() -> Dict[str, SpectrumWindowLimits]:
def get_default_spectrum_window_limits(windows_mapping: Dict = None) -> Dict[str, SpectrumWindowLimits]:
if windows_mapping is None:
windows_mapping = get_default_regions_from_toml_files()
windows = {}
for window_type, window_config in SPECTRUM_WINDOWS_LIMITS.items():
window_name = window_type.name
for window_name, window_config in windows_mapping.items():
windows[window_name] = SpectrumWindowLimits(name=window_name, **window_config)
return windows

Expand Down

0 comments on commit df2ecb2

Please sign in to comment.