From e91c80eb3ec2942475644986305e38fda5bf1f6e Mon Sep 17 00:00:00 2001 From: Roman Novak Date: Tue, 30 Jan 2024 11:43:56 -0800 Subject: [PATCH] Do not use dummy depthwise convolution on GPUs anymore, as performance on single-channel convolutions is now reasonable in latest CuDNN versions. PiperOrigin-RevId: 602794501 --- neural_tangents/_src/stax/linear.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/neural_tangents/_src/stax/linear.py b/neural_tangents/_src/stax/linear.py index c5d63059..caa05d3b 100644 --- a/neural_tangents/_src/stax/linear.py +++ b/neural_tangents/_src/stax/linear.py @@ -3181,26 +3181,21 @@ def get_n_channels(batch_and_channels: int) -> int: Suggested number of channels for depthwise-separable convolution. """ platform = jax.default_backend() - if platform in ['gpu', 'tpu']: + if platform == 'tpu': n_channels = batch_and_channels # Find smallest `n_channels > 1` that divides `batch_and_features`; use - # depthwise-separable CNN. For `n_channels == 1` CuDNN appears to invoke a - # different algorithm (`void cudnn::detail::implicit_convolve_sgemm`) than - # in any other case (`conv2d_c1_k1_nchw_hw_packed_kernel`), and the latter - # seems many-fold faster. - # For TPU, start with `n_channels >= 128`. Beware of precision errors: - # TODO(romann): revisit based on b/154160868. - n_channels_min = 2 if platform == 'gpu' else 128 + # depthwise-separable CNN. For TPU, start with `n_channels >= 128`. + # Beware of precision errors. + n_channels_min = 128 for n_c in range(n_channels_min, batch_and_channels): if batch_and_channels % n_c == 0: n_channels = n_c break - elif platform == 'cpu': - # For CPU minimal channels seems best. Transpose convolution does not - # support depthwise operations. + elif platform in ('cpu', 'gpu'): + # For CPU and GPU minimal channels seems best. n_channels = 1 else: