Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ToDType and fix bugs in RandAugment and TrivialAugmentWide #136

Merged
merged 3 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion keras_aug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
from keras_aug import visualization
from keras_aug._src.version import version

__version__ = "1.0.1"
__version__ = "1.1.0"
27 changes: 23 additions & 4 deletions keras_aug/_src/backend/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ class ImageBackend(DynamicBackend):
def __init__(self, name=None):
super().__init__(name=name)

def transform_dtype(self, images, from_dtype, to_dtype):
def transform_dtype(self, images, from_dtype, to_dtype, scale=True):
# Ref: torchvision.transforms.v2.ToDtype
ops = self.backend
from_dtype = backend.standardize_dtype(from_dtype)
to_dtype = backend.standardize_dtype(to_dtype)

if from_dtype == to_dtype:
return images
if scale is False:
return ops.cast(images, to_dtype)

is_float_input = backend.is_float_dtype(from_dtype)
is_float_output = backend.is_float_dtype(to_dtype)
Expand Down Expand Up @@ -51,13 +53,30 @@ def transform_dtype(self, images, from_dtype, to_dtype):
num_bits_input = self._num_bits_of_dtype(from_dtype)
num_bits_output = self._num_bits_of_dtype(to_dtype)

def right_shift(inputs, bits):
if self.name == "tensorflow":
import tensorflow as tf

return tf.bitwise.right_shift(inputs, bits)
else:
return inputs >> bits

def left_shift(inputs, bits):
if self.name == "tensorflow":
import tensorflow as tf

return tf.bitwise.left_shift(inputs, bits)
else:
return inputs << bits

if num_bits_input > num_bits_output:
return ops.cast(
images >> (num_bits_input - num_bits_output), to_dtype
right_shift(images, (num_bits_input - num_bits_output)),
to_dtype,
)
else:
return ops.cast(images, to_dtype) << (
num_bits_output - num_bits_input
return left_shift(
ops.cast(images, to_dtype), num_bits_output - num_bits_input
)

def crop(self, images, top, left, height, width, data_format=None):
Expand Down
23 changes: 19 additions & 4 deletions keras_aug/_src/layers/base/vision_random_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,17 @@ class VisionRandomLayer(keras.Layer):
IS_DICT = "is_dict"
BATCHED = "batched"

SUPPORTED_INT_DTYPES = ("uint8", "int16", "int32")

def __init__(self, has_generator=True, seed=None, **kwargs):
super().__init__(**kwargs)
# Check dtype
if not backend.is_float_dtype(self.compute_dtype):
if self.compute_dtype != "uint8":
if self.compute_dtype not in self.SUPPORTED_INT_DTYPES:
raise ValueError(
"Only floating and 'uint8' are supported for compute dtype."
f" Received: compute_dtype={self.compute_dtype}"
f"Only floating and {self.SUPPORTED_INT_DTYPES} are "
"supported for compute dtype. "
f"Received: compute_dtype={self.compute_dtype}"
)

self._backend = DynamicBackend(backend.backend())
Expand All @@ -99,6 +102,7 @@ def __init__(self, has_generator=True, seed=None, **kwargs):
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
self.autocast = False
self._transform_dtype_scale = True

@property
def image_dtype(self):
Expand All @@ -122,6 +126,14 @@ def backend(self):
def random_generator(self):
return self._random_generator.random_generator

@property
def transform_dtype_scale(self):
return self._transform_dtype_scale

@transform_dtype_scale.setter
def transform_dtype_scale(self, value):
self._transform_dtype_scale = bool(value)

def get_params(
self,
batch_size,
Expand Down Expand Up @@ -389,7 +401,10 @@ def _cast_inputs(self, inputs):
if self.IMAGES in inputs:
inputs[self.IMAGES] = ops.convert_to_tensor(inputs[self.IMAGES])
inputs[self.IMAGES] = self.image_backend.transform_dtype(
inputs[self.IMAGES], inputs[self.IMAGES].dtype, self.image_dtype
inputs[self.IMAGES],
inputs[self.IMAGES].dtype,
self.image_dtype,
scale=self.transform_dtype_scale,
)
if self.LABELS in inputs:
inputs[self.LABELS] = ops.convert_to_tensor(inputs[self.LABELS])
Expand Down
2 changes: 1 addition & 1 deletion keras_aug/_src/layers/vision/rand_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def get_params(self, batch_size, images=None, **kwargs):
ops.numpy.log(fn_idx_p), self.num_ops, seed=random_generator
)
fn_idx = fn_idx[0]
signed_p = ops.random.uniform([batch_size]) > 0.5
signed_p = ops.random.uniform([batch_size], seed=random_generator) > 0.5
signed = ops.cast(ops.numpy.where(signed_p, 1.0, -1.0), dtype="float32")
return dict(
p=p, # shape: (batch_size,)
Expand Down
56 changes: 56 additions & 0 deletions keras_aug/_src/layers/vision/to_dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import keras
from keras import backend

from keras_aug._src.keras_aug_export import keras_aug_export
from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer


@keras_aug_export(parent_path=["keras_aug.layers.vision"])
@keras.saving.register_keras_serializable(package="keras_aug")
class ToDType(VisionRandomLayer):
    """Convert the input to a given dtype, optionally rescaling the values.

    The `augment_*` hooks of this layer return their inputs untouched. The
    target dtype is passed to the base layer as `dtype`, and the scaling
    choice is stored in `transform_dtype_scale`.

    If `scale` is `True`, the value range will be changed as follows:
    - `"uint8"`: `[0, 255]`
    - `"int16"`: `[-32768, 32767]`
    - `"int32"`: `[-2147483648, 2147483647]`
    - float: `[0.0, 1.0]`

    Args:
        to_dtype: A string specifying the target dtype.
        scale: Whether to scale the values. Defaults to `False`.
    """

    def __init__(self, to_dtype, scale=False, **kwargs):
        target_dtype = backend.standardize_dtype(to_dtype)
        self.scale = bool(scale)
        # The layer dtype is dictated by `to_dtype`, so any caller-supplied
        # `dtype` keyword is discarded instead of being forwarded.
        kwargs.pop("dtype", None)
        super().__init__(has_generator=False, dtype=target_dtype, **kwargs)
        self.to_dtype = target_dtype
        self.transform_dtype_scale = self.scale

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged."""
        return input_shape

    def augment_images(self, images, transformations, **kwargs):
        """Return `images` as received."""
        return images

    def augment_labels(self, labels, transformations, **kwargs):
        """Return `labels` as received."""
        # Codecov flagged this line as not covered by tests.
        return labels

    def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs):
        """Return `bounding_boxes` as received."""
        # Codecov flagged this line as not covered by tests.
        return bounding_boxes

    def augment_segmentation_masks(
        self, segmentation_masks, transformations, **kwargs
    ):
        """Return `segmentation_masks` as received."""
        # Codecov flagged this line as not covered by tests.
        return segmentation_masks

    def augment_keypoints(self, keypoints, transformations, **kwargs):
        """Return `keypoints` as received."""
        # Codecov flagged this line as not covered by tests.
        return keypoints

    def get_config(self):
        """Return the base config extended with `to_dtype` and `scale`."""
        config = super().get_config()
        config["to_dtype"] = self.to_dtype
        config["scale"] = self.scale
        return config
102 changes: 102 additions & 0 deletions keras_aug/_src/layers/vision/to_dtype_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import keras
import numpy as np
from absl.testing import parameterized
from keras import backend
from keras.src import testing
from keras.src.testing.test_utils import named_product

from keras_aug._src.layers.vision.to_dtype import ToDType
from keras_aug._src.utils.test_utils import get_images


class ToDTypeTest(testing.TestCase, parameterized.TestCase):
    """Tests for the `ToDType` layer."""

    def setUp(self):
        # Remember the global data format, then force channels_last so every
        # test starts from the same state.
        self.data_format = backend.image_data_format()
        backend.set_image_data_format("channels_last")
        return super().setUp()

    def tearDown(self) -> None:
        # Restore whatever data format was active before the test ran.
        backend.set_image_data_format(self.data_format)
        return super().tearDown()

    @parameterized.named_parameters(
        named_product(
            from_dtype=["uint8", "int16", "int32", "bfloat16", "float32"],
            to_dtype=["uint8", "int16", "bfloat16", "float32"],
            scale=[True, False],
        )
    )
    def test_correctness(self, from_dtype, to_dtype, scale):
        """Compare the layer output against torchvision's `to_dtype`."""
        import torch
        import torchvision.transforms.v2.functional as TF
        from keras.src.backend.torch import to_torch_dtype

        # Test channels_last
        images = get_images(from_dtype, "channels_last")
        outputs = ToDType(to_dtype, scale)(images)

        # The torch reference is fed float32 in place of bfloat16 input
        # (presumably torch cannot ingest the NumPy bfloat16 array -- confirm).
        reference_input = images
        if from_dtype == "bfloat16":
            reference_input = images.astype("float32")
        expected = TF.to_dtype(
            torch.tensor(np.transpose(reference_input, [0, 3, 1, 2])),
            dtype=to_torch_dtype(to_dtype),
            scale=scale,
        )

        # bfloat16 results are compared after casting both sides to float32.
        expected_dtype = to_dtype
        if to_dtype == "bfloat16":
            outputs = keras.ops.cast(outputs, "float32")
            expected = expected.to(torch.float32)
            expected_dtype = "float32"
        expected = np.transpose(expected.cpu().numpy(), [0, 2, 3, 1])

        self.assertDType(outputs, expected_dtype)
        # For bfloat16 -> uint8/int16 only the dtype is checked; the values
        # are not compared.
        if from_dtype == "bfloat16" and expected_dtype in ("uint8", "int16"):
            return
        self.assertAllClose(outputs, expected)

    def test_shape(self):
        """Symbolic inputs keep their shape in both data formats."""
        cases = (
            # Test dynamic shape
            ("channels_last", (None, None, None, 3)),
            ("channels_first", (None, 3, None, None)),
            # Test static shape
            ("channels_last", (None, 32, 32, 3)),
            ("channels_first", (None, 3, 32, 32)),
        )
        for data_format, shape in cases:
            backend.set_image_data_format(data_format)
            symbolic = keras.KerasTensor(shape)
            result = ToDType("float32", scale=True)(symbolic)
            self.assertEqual(result.shape, shape)

    def test_model(self):
        """The layer can be used inside a functional model."""
        inputs = keras.layers.Input(shape=[None, None, 5])
        outputs = ToDType("float32", scale=True)(inputs)
        model = keras.models.Model(inputs, outputs)
        self.assertEqual(model.output_shape, (None, None, None, 5))

    def test_config(self):
        """A layer rebuilt from its config produces the same output."""
        images = get_images("float32", "channels_last")
        original = ToDType("float32", scale=True)
        expected = original(images)

        restored = ToDType.from_config(original.get_config())
        self.assertAllClose(expected, restored(images))

    def test_tf_data_compatibility(self):
        """The layer can be mapped over a batched `tf.data.Dataset`."""
        import tensorflow as tf

        layer = ToDType("float32", scale=True)
        images = get_images("float32", "channels_last")
        dataset = tf.data.Dataset.from_tensor_slices(images).batch(2).map(layer)
        for batch in dataset.take(1):
            self.assertIsInstance(batch, tf.Tensor)
            self.assertEqual(batch.shape, (2, 32, 32, 3))
6 changes: 4 additions & 2 deletions keras_aug/_src/layers/vision/trivial_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,15 @@ def get_params(self, batch_size, images=None, **kwargs):
random_generator = self.random_generator

p = ops.random.uniform([batch_size], seed=random_generator)
magnitude = ops.random.randint([batch_size], 0, self.num_magnitude_bins)
magnitude = ops.random.randint(
[batch_size], 0, self.num_magnitude_bins, seed=random_generator
)
fn_idx_p = ops.convert_to_tensor([self.fn_idx_p])
fn_idx = ops.random.categorical(
ops.numpy.log(fn_idx_p), 1, seed=random_generator
)
fn_idx = fn_idx[0]
signed_p = ops.random.uniform([batch_size]) > 0.5
signed_p = ops.random.uniform([batch_size], seed=random_generator) > 0.5
signed = ops.cast(ops.numpy.where(signed_p, 1.0, -1.0), dtype="float32")
return dict(
p=p, # shape: (batch_size,)
Expand Down
6 changes: 4 additions & 2 deletions keras_aug/_src/ops/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@


@keras_aug_export(parent_path=["keras_aug.ops.image"])
def transform_dtype(images, from_dtype, to_dtype):
def transform_dtype(images, from_dtype, to_dtype, scale=True):
backend = "tensorflow" if in_tf_graph() else None
return ImageBackend(backend).transform_dtype(images, from_dtype, to_dtype)
return ImageBackend(backend).transform_dtype(

Check warning on line 10 in keras_aug/_src/ops/image.py

View check run for this annotation

Codecov / codecov/patch

keras_aug/_src/ops/image.py#L10

Added line #L10 was not covered by tests
images, from_dtype, to_dtype, scale=scale
)


@keras_aug_export(parent_path=["keras_aug.ops.image"])
Expand Down
8 changes: 8 additions & 0 deletions keras_aug/_src/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@ def get_images(dtype, data_format="channels_first", size=(32, 32)):
x = np.random.uniform(0, 1, (2, 3, *size)).astype(dtype)
elif dtype == "bfloat16":
x = np.random.uniform(0, 1, (2, 3, *size)).astype(dtype)
elif dtype == "float16":
x = np.random.uniform(0, 1, (2, 3, *size)).astype(dtype)
elif dtype == "uint8":
x = np.random.uniform(0, 255, (2, 3, *size)).astype(dtype)
elif dtype == "int8":
x = np.random.uniform(-128, 127, (2, 3, *size)).astype(dtype)
elif dtype == "int16":
x = np.random.uniform(-32768, 32767, (2, 3, *size)).astype(dtype)
elif dtype == "int32":
x = np.random.uniform(-2147483648, 2147483647, (2, 3, *size)).astype(
dtype
)
if data_format == "channels_last":
x = np.transpose(x, [0, 2, 3, 1])
return x
Expand Down
2 changes: 1 addition & 1 deletion keras_aug/_src/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from keras_aug._src.keras_aug_export import keras_aug_export

__version__ = "1.0.1"
__version__ = "1.1.0"


@keras_aug_export("keras_aug")
Expand Down
1 change: 1 addition & 0 deletions keras_aug/layers/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@
from keras_aug._src.layers.vision.random_solarize import RandomSolarize
from keras_aug._src.layers.vision.rescale import Rescale
from keras_aug._src.layers.vision.resize import Resize
from keras_aug._src.layers.vision.to_dtype import ToDType
from keras_aug._src.layers.vision.trivial_augment import TrivialAugmentWide