Skip to content

Commit

Permalink
Add p argument to RandAugment and TrivialAugmentWide and fix bu…
Browse files Browse the repository at this point in the history
…g in `RandomApply` (#134)

* Add p args to RandAugment and fix bug in RandomApply

* Update

* Add p to TrivialAugmentWide
  • Loading branch information
james77777778 authored Aug 9, 2024
1 parent e3dedff commit b5908a4
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 17 deletions.
2 changes: 1 addition & 1 deletion keras_aug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
from keras_aug import visualization
from keras_aug._src.version import version

__version__ = "1.0.0"
__version__ = "1.0.1"
8 changes: 7 additions & 1 deletion keras_aug/_src/layers/composition/random_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,14 @@ def call(self, inputs):
ops = self.backend
p = self.get_params()

ori_inputs = inputs
if isinstance(inputs, dict):
inputs = inputs.copy()

outputs = ops.core.cond(
p < self.p, lambda: self._apply_transforms(inputs), lambda: inputs
p < self.p,
lambda: self._apply_transforms(inputs),
lambda: ori_inputs,
)
return outputs

Expand Down
10 changes: 8 additions & 2 deletions keras_aug/_src/layers/composition/random_apply_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.src import testing

from keras_aug._src.layers.composition.random_apply import RandomApply
from keras_aug._src.layers.vision.rand_augment import RandAugment
from keras_aug._src.layers.vision.random_grayscale import RandomGrayscale
from keras_aug._src.layers.vision.resize import Resize
from keras_aug._src.utils.test_utils import get_images
Expand Down Expand Up @@ -81,9 +82,14 @@ def test_config(self):
def test_tf_data_compatibility(self):
import tensorflow as tf

layer = RandomApply(transforms=RandomGrayscale(p=1.0))
def to_dict(x):
return {"images": x, "labels": tf.convert_to_tensor([0, 1])}

layer = RandomApply(transforms=[RandAugment()], p=0.5)
x = get_images("float32", "channels_last")
ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer)
ds = tf.data.Dataset.from_tensor_slices(x).batch(2)
ds = ds.map(to_dict).map(layer)
for output in ds.take(1):
output = output["images"]
self.assertIsInstance(output, tf.Tensor)
self.assertEqual(output.shape, (2, 32, 32, 3))
31 changes: 25 additions & 6 deletions keras_aug/_src/layers/vision/rand_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class RandAugment(VisionRandomLayer):
- [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719)
Args:
p: A float specifying the probability. Defaults to `1.0`.
num_ops: Number of augmentation transformations to apply sequentially.
Defaults to `2`.
magnitude: Magnitude for all the transformations. Defaults to `9`.
Expand All @@ -49,6 +50,7 @@ class RandAugment(VisionRandomLayer):

def __init__(
self,
p: float = 1.0,
num_ops: int = 2,
magnitude: int = 9,
num_magnitude_bins: int = 31,
Expand All @@ -70,6 +72,7 @@ def __init__(
f"num_magnitude_bins={num_magnitude_bins}"
)

self.p = float(p)
self.num_ops = int(num_ops)
self.magnitude = int(magnitude)
self.num_magnitude_bins = int(num_magnitude_bins)
Expand Down Expand Up @@ -120,7 +123,7 @@ def __init__(
)
p = [1.0] * len(self.augmentation_space)
total = sum(p)
self.p = [prob / total for prob in p]
self.fn_idx_p = [prob / total for prob in p]

def compute_output_shape(self, input_shape):
return input_shape
Expand All @@ -129,17 +132,19 @@ def get_params(self, batch_size, images=None, **kwargs):
ops = self.backend
random_generator = self.random_generator

p = ops.random.uniform([batch_size], seed=random_generator)
magnitude = ops.numpy.full(
[self.num_ops, batch_size], self.magnitude, dtype="int32"
)
fn_idx_p = ops.convert_to_tensor([self.p])
fn_idx_p = ops.convert_to_tensor([self.fn_idx_p])
fn_idx = ops.random.categorical(
ops.numpy.log(fn_idx_p), self.num_ops, seed=random_generator
)
fn_idx = fn_idx[0]
signed_p = ops.random.uniform([batch_size]) > 0.5
signed = ops.cast(ops.numpy.where(signed_p, 1.0, -1.0), dtype="float32")
return dict(
p=p, # shape: (batch_size,)
magnitude=magnitude, # shape: (self.num_ops, batch_size)
fn_idx=fn_idx, # shape: (self.num_ops,)
signed=signed, # shape: (batch_size,)
Expand Down Expand Up @@ -340,20 +345,28 @@ def _apply_images_transform(self, images, magnitude, idx, signed):
return images

def augment_images(self, images, transformations, **kwargs):
ops = self.backend

p = transformations["p"]
magnitude = transformations["magnitude"]
fn_idx = transformations["fn_idx"]
signed = transformations["signed"]
prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3])
for i in range(self.num_ops):
idx = fn_idx[i]
m = magnitude[i]
images = self._apply_images_transform(images, m, idx, signed)
images = ops.numpy.where(
prob,
self._apply_images_transform(images, m, idx, signed),
images,
)
return images

def augment_labels(self, labels, transformations, **kwargs):
return labels

def _apply_bounding_boxes_transform(
self, bounding_boxes, height, width, magnitude, idx, signed
self, bounding_boxes, height, width, p, magnitude, idx, signed
):
ops = self.backend

Expand Down Expand Up @@ -485,7 +498,11 @@ def _apply_bounding_boxes_transform(
width=width,
)
)
boxes = ops.core.switch(idx, transforms, bounding_boxes["boxes"])
prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2])
boxes = bounding_boxes["boxes"]
boxes = ops.numpy.where(
prob, ops.core.switch(idx, transforms, boxes), boxes
)
bounding_boxes = bounding_boxes.copy()
bounding_boxes["boxes"] = boxes
return bounding_boxes
Expand All @@ -505,6 +522,7 @@ def augment_bounding_boxes(
)
ops = self.backend

p = transformations["p"]
magnitude = transformations["magnitude"]
fn_idx = transformations["fn_idx"]
signed = transformations["signed"]
Expand All @@ -523,7 +541,7 @@ def augment_bounding_boxes(
idx = fn_idx[i]
m = magnitude[i]
bounding_boxes = self._apply_bounding_boxes_transform(
bounding_boxes, height, width, m, idx, signed
bounding_boxes, height, width, p, m, idx, signed
)
bounding_boxes = self.bbox_backend.clip_to_images(
bounding_boxes,
Expand All @@ -545,6 +563,7 @@ def get_config(self):
config = super().get_config()
config.update(
{
"p": self.p,
"num_ops": self.num_ops,
"magnitude": self.magnitude,
"num_magnitude_bins": self.num_magnitude_bins,
Expand Down
10 changes: 10 additions & 0 deletions keras_aug/_src/layers/vision/rand_augment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def test_config(self):
y2 = layer(x)
self.assertEqual(y.shape, y2.shape)

# Test `p=0.0`
layer = RandAugment(p=0.0)
y = layer(x)

layer = RandAugment.from_config(layer.get_config())
y2 = layer(x)
self.assertAllClose(y, x)
self.assertAllClose(y2, x)
self.assertEqual(y.shape, y2.shape)

def test_tf_data_compatibility(self):
import tensorflow as tf

Expand Down
31 changes: 25 additions & 6 deletions keras_aug/_src/layers/vision/trivial_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class TrivialAugmentWide(VisionRandomLayer):
- [TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation](https://arxiv.org/abs/2103.10158)
Args:
p: A float specifying the probability. Defaults to `1.0`.
num_magnitude_bins: The number of different magnitude values. Defaults
to `31`.
geometric: Whether to include geometric augmentations. This
Expand All @@ -46,6 +47,7 @@ class TrivialAugmentWide(VisionRandomLayer):

def __init__(
self,
p: float = 1.0,
num_magnitude_bins: int = 31,
geometric: bool = True,
interpolation: str = "bilinear",
Expand All @@ -56,6 +58,7 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self.p = float(p)
self.num_magnitude_bins = int(num_magnitude_bins)
self.geometric = bool(geometric)
self.interpolation = standardize_interpolation(interpolation)
Expand Down Expand Up @@ -106,7 +109,7 @@ def __init__(
)
p = [1.0] * len(self.augmentation_space)
total = sum(p)
self.p = [prob / total for prob in p]
self.fn_idx_p = [prob / total for prob in p]

def compute_output_shape(self, input_shape):
return input_shape
Expand All @@ -115,15 +118,17 @@ def get_params(self, batch_size, images=None, **kwargs):
ops = self.backend
random_generator = self.random_generator

p = ops.random.uniform([batch_size], seed=random_generator)
magnitude = ops.random.randint([batch_size], 0, self.num_magnitude_bins)
fn_idx_p = ops.convert_to_tensor([self.p])
fn_idx_p = ops.convert_to_tensor([self.fn_idx_p])
fn_idx = ops.random.categorical(
ops.numpy.log(fn_idx_p), 1, seed=random_generator
)
fn_idx = fn_idx[0]
signed_p = ops.random.uniform([batch_size]) > 0.5
signed = ops.cast(ops.numpy.where(signed_p, 1.0, -1.0), dtype="float32")
return dict(
p=p, # shape: (batch_size,)
magnitude=magnitude, # shape: (batch_size,)
fn_idx=fn_idx, # shape: (1,)
signed=signed, # shape: (batch_size,)
Expand Down Expand Up @@ -301,17 +306,25 @@ def _apply_images_transform(self, images, magnitude, idx, signed):
return images

def augment_images(self, images, transformations, **kwargs):
ops = self.backend

p = transformations["p"]
magnitude = transformations["magnitude"]
fn_idx = transformations["fn_idx"][0]
signed = transformations["signed"]
images = self._apply_images_transform(images, magnitude, fn_idx, signed)
prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3])
images = ops.numpy.where(
prob,
self._apply_images_transform(images, magnitude, fn_idx, signed),
images,
)
return images

def augment_labels(self, labels, transformations, **kwargs):
return labels

def _apply_bounding_boxes_transform(
self, bounding_boxes, height, width, magnitude, idx, signed
self, bounding_boxes, height, width, p, magnitude, idx, signed
):
ops = self.backend

Expand Down Expand Up @@ -423,7 +436,11 @@ def _apply_bounding_boxes_transform(
width=width,
)
)
boxes = ops.core.switch(idx, transforms, bounding_boxes["boxes"])
prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2])
boxes = bounding_boxes["boxes"]
boxes = ops.numpy.where(
prob, ops.core.switch(idx, transforms, boxes), boxes
)
bounding_boxes = bounding_boxes.copy()
bounding_boxes["boxes"] = boxes
return bounding_boxes
Expand All @@ -443,6 +460,7 @@ def augment_bounding_boxes(
)
ops = self.backend

p = transformations["p"]
magnitude = transformations["magnitude"]
fn_idx = transformations["fn_idx"][0]
signed = transformations["signed"]
Expand All @@ -458,7 +476,7 @@ def augment_bounding_boxes(
dtype=self.bounding_box_dtype,
)
bounding_boxes = self._apply_bounding_boxes_transform(
bounding_boxes, height, width, magnitude, fn_idx, signed
bounding_boxes, height, width, p, magnitude, fn_idx, signed
)
bounding_boxes = self.bbox_backend.clip_to_images(
bounding_boxes,
Expand All @@ -480,6 +498,7 @@ def get_config(self):
config = super().get_config()
config.update(
{
"p": self.p,
"num_magnitude_bins": self.num_magnitude_bins,
"geometric": self.geometric,
"interpolation": self.interpolation,
Expand Down
10 changes: 10 additions & 0 deletions keras_aug/_src/layers/vision/trivial_augment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def test_config(self):
y2 = layer(x)
self.assertEqual(y.shape, y2.shape)

# Test `p=0.0`
layer = TrivialAugmentWide(p=0.0)
y = layer(x)

layer = TrivialAugmentWide.from_config(layer.get_config())
y2 = layer(x)
self.assertAllClose(y, x)
self.assertAllClose(y2, x)
self.assertEqual(y.shape, y2.shape)

def test_tf_data_compatibility(self):
import tensorflow as tf

Expand Down
2 changes: 1 addition & 1 deletion keras_aug/_src/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from keras_aug._src.keras_aug_export import keras_aug_export

__version__ = "1.0.0"
__version__ = "1.0.1"


@keras_aug_export("keras_aug")
Expand Down

0 comments on commit b5908a4

Please sign in to comment.