diff --git a/albumentations/core/composition.py b/albumentations/core/composition.py index 88cf0f5a0..b31965aa3 100644 --- a/albumentations/core/composition.py +++ b/albumentations/core/composition.py @@ -106,7 +106,6 @@ def __init__( transforms: TransformsSeqType, p: float, mask_interpolation: int | None = None, - seed: int | None = None, ): if isinstance(transforms, (BaseCompose, BasicTransform)): warnings.warn( @@ -125,9 +124,9 @@ def __init__( self.processors: dict[str, BboxProcessor | KeypointsProcessor] = {} self._set_keys() self.set_mask_interpolation(mask_interpolation) - self.seed = seed - self.random_generator = np.random.default_rng(seed) - self.set_random_state(seed) # This will propagate to children + self.seed: int | None = None + self.random_generator = np.random.default_rng(self.seed) + self.set_random_state(self.seed) # This will propagate to children def set_random_state(self, seed: int | None) -> None: """Set random state for this compose and all nested transforms. @@ -137,6 +136,7 @@ def set_random_state(self, seed: int | None) -> None: """ self.seed = seed self.random_generator = np.random.default_rng(seed) + if seed is not None: random.seed(seed) # Set standard library random seed diff --git a/albumentations/core/transforms_interface.py b/albumentations/core/transforms_interface.py index f395cb02a..a34c1ad08 100644 --- a/albumentations/core/transforms_interface.py +++ b/albumentations/core/transforms_interface.py @@ -89,12 +89,14 @@ def __init__(self, p: float = 0.5, always_apply: bool | None = None): self.random_generator = np.random.default_rng(self.seed) self.py_random = random.Random(self.seed) # Create instance instead of using global - def set_random_state(self, seed: int) -> None: + def set_random_state(self, seed: int | None) -> None: """Set random state for this transform and all nested transforms. Args: seed: Random seed to use """ + if seed is not None: + random.seed(seed) # Set standard library random seed self.seed = seed self.random_generator = np.random.default_rng(seed) self.py_random.seed(seed)