Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Oct 25, 2024
1 parent 54e8cb9 commit 4defb14
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 4 additions & 4 deletions albumentations/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def __init__(
transforms: TransformsSeqType,
p: float,
mask_interpolation: int | None = None,
seed: int | None = None,
):
if isinstance(transforms, (BaseCompose, BasicTransform)):
warnings.warn(
Expand All @@ -125,9 +124,9 @@ def __init__(
self.processors: dict[str, BboxProcessor | KeypointsProcessor] = {}
self._set_keys()
self.set_mask_interpolation(mask_interpolation)
self.seed = seed
self.random_generator = np.random.default_rng(seed)
self.set_random_state(seed) # This will propagate to children
self.seed: int | None = None
self.random_generator = np.random.default_rng(self.seed)
self.set_random_state(self.seed) # This will propagate to children

def set_random_state(self, seed: int | None) -> None:
"""Set random state for this compose and all nested transforms.
Expand All @@ -137,6 +136,7 @@ def set_random_state(self, seed: int | None) -> None:
"""
self.seed = seed
self.random_generator = np.random.default_rng(seed)

if seed is not None:
random.seed(seed) # Set standard library random seed

Expand Down
4 changes: 3 additions & 1 deletion albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ def __init__(self, p: float = 0.5, always_apply: bool | None = None):
self.random_generator = np.random.default_rng(self.seed)
self.py_random = random.Random(self.seed) # Create instance instead of using global

def set_random_state(self, seed: int) -> None:
def set_random_state(self, seed: int | None) -> None:
"""Set random state for this transform and all nested transforms.
Args:
seed: Random seed to use
"""
if seed is not None:
random.seed(seed) # Set standard library random seed
self.seed = seed
self.random_generator = np.random.default_rng(seed)
self.py_random.seed(seed)
Expand Down

0 comments on commit 4defb14

Please sign in to comment.