Skip to content

Commit

Permalink
Change noise field to default_factory (#26)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Pinder <[email protected]>
  • Loading branch information
thomaspinder and Thomas Pinder authored Oct 23, 2024
1 parent 8cfdcfa commit 311dbd4
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
matrix:
# Select the Python versions to test against
os: ["ubuntu-latest", "macos-latest"]
python-version: ["3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]
fail-fast: true
steps:
- name: Check out the code
Expand Down
2 changes: 1 addition & 1 deletion src/causal_validation/__about__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.0.8"
__version__ = "0.0.9"

__all__ = ["__version__"]
6 changes: 4 additions & 2 deletions src/causal_validation/transforms/noise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Tuple

from jaxtyping import Float
Expand All @@ -18,7 +18,9 @@ class Noise(AdditiveTransform):
Normal with 0 loc and 0.1 scale.
"""

noise_dist: TimeVaryingParameter = TimeVaryingParameter(sampling_dist=norm(0, 0.1))
noise_dist: TimeVaryingParameter = field(
default_factory=lambda: TimeVaryingParameter(sampling_dist=norm(0, 0.1))
)
_slots: Tuple[str] = ("noise_dist",)

def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]:
Expand Down

0 comments on commit 311dbd4

Please sign in to comment.