Skip to content

Commit

Permalink
Do not use dummy depthwise convolution on GPUs anymore, as performanc…
Browse files Browse the repository at this point in the history
…e on single-channel convolutions is now reasonable in latest CuDNN versions.

PiperOrigin-RevId: 602794501
  • Loading branch information
romanngg committed Feb 1, 2024
1 parent 06c8ad8 commit e91c80e
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions neural_tangents/_src/stax/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3181,26 +3181,21 @@ def get_n_channels(batch_and_channels: int) -> int:
Suggested number of channels for depthwise-separable convolution.
"""
platform = jax.default_backend()
if platform in ['gpu', 'tpu']:
if platform == 'tpu':
n_channels = batch_and_channels

# Find smallest `n_channels > 1` that divides `batch_and_features`; use
# depthwise-separable CNN. For `n_channels == 1` CuDNN appears to invoke a
# different algorithm (`void cudnn::detail::implicit_convolve_sgemm`) than
# in any other case (`conv2d_c1_k1_nchw_hw_packed_kernel`), and the latter
# seems many-fold faster.
# For TPU, start with `n_channels >= 128`. Beware of precision errors:
# TODO(romann): revisit based on b/154160868.
n_channels_min = 2 if platform == 'gpu' else 128
# depthwise-separable CNN. For TPU, start with `n_channels >= 128`.
# Beware of precision errors.
n_channels_min = 128

for n_c in range(n_channels_min, batch_and_channels):
if batch_and_channels % n_c == 0:
n_channels = n_c
break

elif platform == 'cpu':
# For CPU minimal channels seems best. Transpose convolution does not
# support depthwise operations.
elif platform in ('cpu', 'gpu'):
# For CPU and GPU minimal channels seems best.
n_channels = 1

else:
Expand Down

0 comments on commit e91c80e

Please sign in to comment.