diff --git a/opacus/data_loader.py b/opacus/data_loader.py index 5350e725..0dac2a75 100644 --- a/opacus/data_loader.py +++ b/opacus/data_loader.py @@ -70,7 +70,7 @@ def shape_safe(x: Any) -> Tuple: Returns: ``x.shape`` if attribute exists, empty tuple otherwise """ - return x.shape if hasattr(x, "shape") else () + return getattr(x, "shape", ()) def dtype_safe(x: Any) -> Union[torch.dtype, Type]: @@ -83,7 +83,7 @@ def dtype_safe(x: Any) -> Union[torch.dtype, Type]: Returns: ``x.dtype`` if attribute exists, type of x otherwise """ - return x.dtype if hasattr(x, "dtype") else type(x) + return getattr(x, "dtype", type(x)) class DPDataLoader(DataLoader):