diff --git a/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py
new file mode 100644
index 00000000000..e0f5587a65a
--- /dev/null
+++ b/ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py
@@ -0,0 +1,217 @@
+import functools
+import operator
+import os
+import os.path
+import sys
+import numpy as np
+import pytest
+import lbann.contrib.args
+
+# CI utilities
+current_file = os.path.realpath(__file__)
+current_dir = os.path.dirname(current_file)
+sys.path.insert(0, os.path.join(os.path.dirname(current_dir), "common_python"))
+import tools
+
+# ==============================================
+# Objects for Python data reader
+# ==============================================
+# Note: The Python data reader imports this file as a module and calls
+# the functions below to ingest data.
+
+# Data
+np.random.seed(20200115)
+_num_samples = 15
+_sample_dims = (15, 5, 1)
+_sample_size = functools.reduce(operator.mul, _sample_dims)
+_samples = np.random.normal(loc=0.5, size=(_num_samples, _sample_size)).astype(
+    np.float32
+)
+
+
+# Sample access functions
+def get_sample(index):
+    return _samples[index, :]
+
+
+def num_samples():
+    return _num_samples
+
+
+def sample_dims():
+    return (_sample_size,)
+
+
+# ==============================================
+# NumPy implementation
+# ==============================================
+
+
+def numpy_channelwise_softmax(x):
+    if x.dtype is not np.float64:
+        x = x.astype(np.float64)
+    axis = tuple(range(1, x.ndim))
+    shift = np.max(x, axis=axis, keepdims=True)
+    y = np.exp(x - shift)
+    return y / np.sum(y, axis=axis, keepdims=True)
+
+
+# ==============================================
+# Setup LBANN experiment
+# ==============================================
+
+
+def setup_experiment(lbann, weekly):
+    """Construct LBANN experiment.
+
+    Args:
+        lbann (module): Module for LBANN Python frontend
+
+    """
+    mini_batch_size = num_samples() // 2
+    trainer = lbann.Trainer(mini_batch_size)
+    model = construct_model(lbann)
+    data_reader = construct_data_reader(lbann)
+    optimizer = lbann.NoOptimizer()
+    return (
+        trainer,
+        model,
+        data_reader,
+        optimizer,
+        None,
+    )  # Don't request any specific number of nodes
+
+
+def create_parallel_strategy(num_channel_groups):
+    return {"channel_groups": num_channel_groups, "filter_groups": num_channel_groups}
+
+
+def construct_model(lbann):
+    """Construct LBANN model.
+
+    Args:
+        lbann (module): Module for LBANN Python frontend
+
+    """
+
+    # Input data
+    # Note: Sum with a weights layer so that gradient checking will
+    # verify that error signals are correct.
+    x_weights = lbann.Weights(
+        optimizer=lbann.SGD(),
+        initializer=lbann.ConstantInitializer(value=0.0),
+        name="input_weights",
+    )
+    x = lbann.Sum(
+        lbann.Reshape(lbann.Input(data_field="samples"), dims=_sample_dims),
+        lbann.WeightsLayer(weights=x_weights, dims=_sample_dims),
+    )
+    x_lbann = x
+    obj = []
+    metrics = []
+    callbacks = []
+
+    num_channel_groups = tools.gpus_per_node(lbann)
+    if num_channel_groups == 0:
+        e = "this test requires GPUs."
+ print("Skip - " + e) + pytest.skip(e) + + # ------------------------------------------ + # Data-parallel layout + # ------------------------------------------ + + # LBANN implementation + x = x_lbann + + y = lbann.ChannelwiseSoftmax( + x, + data_layout="data_parallel", + parallel_strategy=create_parallel_strategy(num_channel_groups), + name="Channelwise_softmax_distconv", + ) + z = lbann.L2Norm2(y) + obj.append(z) + metrics.append(lbann.Metric(z, name="channelwise split distconv")) + + # NumPy implementation + vals = [] + for i in range(num_samples()): + x = get_sample(i).reshape(_sample_dims).astype(np.float64) + y = numpy_channelwise_softmax(x) + z = tools.numpy_l2norm2(y) + vals.append(z) + val = np.mean(vals) + tol = 8 * val * np.finfo(np.float32).eps + callbacks.append( + lbann.CallbackCheckMetric( + metric=metrics[-1].name, + lower_bound=val - tol, + upper_bound=val + tol, + error_on_failure=True, + execution_modes="test", + ) + ) + + # ------------------------------------------ + # Gradient checking + # ------------------------------------------ + + callbacks.append(lbann.CallbackCheckGradients(error_on_failure=True)) + + # ------------------------------------------ + # Construct model + # ------------------------------------------ + + num_epochs = 0 + return lbann.Model( + num_epochs, + layers=lbann.traverse_layer_graph(x_lbann), + objective_function=obj, + metrics=metrics, + callbacks=callbacks, + ) + + +def construct_data_reader(lbann): + """Construct Protobuf message for Python data reader. + + The Python data reader will import the current Python file to + access the sample access functions. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + # Note: The training data reader should be removed when + # https://github.com/LLNL/lbann/issues/1098 is resolved. + message = lbann.reader_pb2.DataReader() + message.reader.extend( + [ + tools.create_python_data_reader( + lbann, current_file, "get_sample", "num_samples", "sample_dims", "train" + ) + ] + ) + message.reader.extend( + [ + tools.create_python_data_reader( + lbann, current_file, "get_sample", "num_samples", "sample_dims", "test" + ) + ] + ) + return message + + +# ============================================== +# Setup PyTest +# ============================================== + +# Create test functions that can interact with PyTest +for _test_func in tools.create_tests( + setup_experiment, + __file__, + environment=lbann.contrib.args.get_distconv_environment(), +): + globals()[_test_func.__name__] = _test_func diff --git a/include/lbann/layers/CMakeLists.txt b/include/lbann/layers/CMakeLists.txt index fa6442b3bc8..f32d5766043 100644 --- a/include/lbann/layers/CMakeLists.txt +++ b/include/lbann/layers/CMakeLists.txt @@ -1,5 +1,5 @@ ################################################################################ -## Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +## Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. ## Produced at the Lawrence Livermore National Laboratory. ## Written by the LBANN Research Team (B. Van Essen, et al.) listed in ## the CONTRIBUTORS file. diff --git a/include/lbann/layers/misc/CMakeLists.txt b/include/lbann/layers/misc/CMakeLists.txt index d84c4e0accf..18da384a7ad 100644 --- a/include/lbann/layers/misc/CMakeLists.txt +++ b/include/lbann/layers/misc/CMakeLists.txt @@ -1,5 +1,5 @@ ################################################################################ -## Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. 
+## Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
 ## Produced at the Lawrence Livermore National Laboratory.
 ## Written by the LBANN Research Team (B. Van Essen, et al.) listed in
 ## the CONTRIBUTORS file.
@@ -40,5 +40,8 @@ set_full_path(THIS_DIR_HEADERS
   variance.hpp
   )
 
+if (LBANN_HAS_DISTCONV)
+  add_subdirectory(distconv)
+endif()
 # Propagate the files up the tree
 set(HEADERS "${HEADERS}" "${THIS_DIR_HEADERS}" PARENT_SCOPE)
diff --git a/include/lbann/layers/misc/channelwise_softmax.hpp b/include/lbann/layers/misc/channelwise_softmax.hpp
index 7f0e792acae..086ace20b73 100644
--- a/include/lbann/layers/misc/channelwise_softmax.hpp
+++ b/include/lbann/layers/misc/channelwise_softmax.hpp
@@ -1,5 +1,5 @@
 ////////////////////////////////////////////////////////////////////////////////
-// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
+// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
 // Produced at the Lawrence Livermore National Laboratory.
 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
 // the CONTRIBUTORS file.
@@ -32,8 +32,43 @@
 #include "lbann/proto/layers.pb.h"
 
+#ifdef LBANN_HAS_DISTCONV
+#include "lbann/layers/data_type_distconv_adapter.hpp"
+#include "lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp"
+#endif
+
 namespace lbann {
 
+#ifdef LBANN_HAS_DISTCONV
+namespace dc {
+template <typename TensorDataType>
+using ChannelwiseSoftmax =
+  ::distconv::ChannelwiseSoftmax<Backend, TensorDataType>;
+} // namespace dc
+
+template <typename TensorDataType, data_layout Layout, El::Device Device>
+class channelwise_softmax_distconv_adapter
+  : public data_type_distconv_adapter<TensorDataType>
+{
+public:
+  using TensorDevType =
+    typename data_type_distconv_adapter<TensorDataType>::TensorDevType;
+
+  channelwise_softmax_distconv_adapter(Layer& layer)
+    : data_type_distconv_adapter<TensorDataType>(layer)
+  {}
+
+  virtual ~channelwise_softmax_distconv_adapter() = default;
+  void setup_distributions(tensor_overlap_constraints& constraints) override;
+  void setup_layer(size_t workspace_capacity) override;
+  void fp_compute();
+  void bp_compute();
+  std::unique_ptr<dc::ChannelwiseSoftmax<TensorDataType>>
+    m_channelwise_softmax_operator;
+}; // class definition channelwise_softmax_distconv_adapter
+
+#endif // LBANN_HAS_DISTCONV
+
 /** @brief Apply softmax to tensor channels.
* * The input tensor is sliced along the first tensor dimension (the @@ -93,6 +128,19 @@ class channelwise_softmax_layer : public data_type_layer void fp_compute() override; void bp_compute() override; +#ifdef LBANN_HAS_DISTCONV + friend class channelwise_softmax_distconv_adapter; + +protected: + void setup_distconv_adapter() override; + bool is_distconv_supported() const override; + channelwise_softmax_distconv_adapter& + get_distconv_adapter() override; + const channelwise_softmax_distconv_adapter& + get_distconv_adapter() const override; +#endif // LBANN_HAS_DISTCONV private: void get_channel_size_and_stride(El::Int& channel_size, El::Int& channel_stride, @@ -159,9 +207,115 @@ El::Device channelwise_softmax_layer:: return Device; } +#ifdef LBANN_HAS_DISTCONV + // ========================================================= -// Explicit template instantiation +// DistConv-Adapter member functions // ========================================================= +template +void channelwise_softmax_distconv_adapter:: + setup_distributions(tensor_overlap_constraints& constraints) +{ + data_type_distconv_adapter::setup_distributions(constraints); + + for (auto& d : this->m_prev_activations_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } + for (auto& d : this->m_activations_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } + for (auto& d : this->m_prev_error_signals_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } + for (auto& d : this->m_error_signals_dists) { + d.clear_overlap(); + constraints.mark_updated(d); + constraints.mark_invariant(d); + } +} + +template +void channelwise_softmax_distconv_adapter:: + setup_layer(size_t workspace_capacity) +{ + data_type_distconv_adapter::setup_layer(workspace_capacity); + + m_channelwise_softmax_operator = + std::make_unique>(dc::get_backend()); +} + +template +void channelwise_softmax_distconv_adapter:: + fp_compute() +{ + auto& layer = + dynamic_cast&>( + this->layer()); + m_channelwise_softmax_operator->forward(this->get_prev_activations(0), + this->get_activations(0)); +} + +template +void channelwise_softmax_distconv_adapter:: + bp_compute() +{ + auto& layer = + dynamic_cast&>( + this->layer()); + m_channelwise_softmax_operator->backward(this->get_activations(0), + this->get_prev_error_signals(), + this->get_error_signals(0)); +} +// ============================================================= +// DistConv-enabled Channelwise-Softmax member functions +// ============================================================= + +template +bool channelwise_softmax_layer:: + is_distconv_supported() const +{ + return Device == El::Device::GPU && Layout == data_layout::DATA_PARALLEL; +} + +template +void channelwise_softmax_layer:: + setup_distconv_adapter() +{ + this->get_distconv_adapter_ptr() = std::make_unique< + channelwise_softmax_distconv_adapter>( + *this); +} + +template +const channelwise_softmax_distconv_adapter& +channelwise_softmax_layer:: + get_distconv_adapter() const +{ + return dynamic_cast&>( + data_type_layer::get_distconv_adapter()); +} + +template +channelwise_softmax_distconv_adapter& +channelwise_softmax_layer:: + get_distconv_adapter() +{ + return const_cast< + channelwise_softmax_distconv_adapter&>( + static_cast< + const channelwise_softmax_layer&>(*this) + .get_distconv_adapter()); +} + +#endif // LBANN_HAS_DISTCONV #ifndef LBANN_CHANNELWISE_SOFTMAX_LAYER_INSTANTIATE #define PROTO_DEVICE(T, 
Device) \ diff --git a/include/lbann/layers/misc/channelwise_softmax_impl.hpp b/include/lbann/layers/misc/channelwise_softmax_impl.hpp index b13095d9fcc..6879eb294b5 100644 --- a/include/lbann/layers/misc/channelwise_softmax_impl.hpp +++ b/include/lbann/layers/misc/channelwise_softmax_impl.hpp @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. @@ -53,6 +53,28 @@ void channelwise_softmax_layer::setup_dims() } this->set_output_dims(this->get_input_dims()); +#ifdef LBANN_HAS_DISTCONV + + if (this->distconv_enabled()) { + // Additional checks when distconv mode is enabled + const auto& input_dims = this->get_input_dims(); + const auto& output_dims = this->get_output_dims(); + + if (input_dims.size() != 3 || output_dims.size() != 3) { + LBANN_ERROR( + this->get_type(), + " layer \"", + this->get_name(), + "\" ", + "expects an input and output tensor with 3 dimensions (channel, *, *), " + "but it has been configured as a ", + input_dims.size(), + "-D input tensor and ", + output_dims.size(), + "-D output tensor"); + } + } +#endif // LBANN_HAS_DISTCONV } template diff --git a/include/lbann/layers/misc/distconv/CMakeLists.txt b/include/lbann/layers/misc/distconv/CMakeLists.txt new file mode 100644 index 00000000000..fc835a868a5 --- /dev/null +++ b/include/lbann/layers/misc/distconv/CMakeLists.txt @@ -0,0 +1,31 @@ +################################################################################ +## Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. +## Produced at the Lawrence Livermore National Laboratory. +## Written by the LBANN Research Team (B. Van Essen, et al.) listed in +## the CONTRIBUTORS file. +## +## LLNL-CODE-697807. +## All rights reserved. +## +## This file is part of LBANN: Livermore Big Artificial Neural Network +## Toolkit. For details, see http://software.llnl.gov/LBANN or +## https://github.com/LLNL/LBANN. +## +## Licensed under the Apache License, Version 2.0 (the "Licensee"); you +## may not use this file except in compliance with the License. You may +## obtain a copy of the License at: +## +## http://www.apache.org/licenses/LICENSE-2.0 +## +## Unless required by applicable law or agreed to in writing, software +## distributed under the License is distributed on an "AS IS" BASIS, +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +## implied. See the License for the specific language governing +## permissions and limitations under the license. +################################################################################ +set_full_path(THIS_DIR_HEADERS + distconv_channelwise_softmax.hpp + ) + +# Propagate the files up the tree +set(HEADERS "${HEADERS}" "${THIS_DIR_HEADERS}" PARENT_SCOPE) diff --git a/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp b/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp new file mode 100644 index 00000000000..4ca40035f84 --- /dev/null +++ b/include/lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp @@ -0,0 +1,60 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. 
+// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. +// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. You may +// obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the license. +//////////////////////////////////////////////////////////////////////////////// + +#ifndef LBANN_LAYERS_MISC_DISTCONV_CHANNELWISE_SOFTMAX +#define LBANN_LAYERS_MISC_DISTCONV_CHANNELWISE_SOFTMAX +#include "lbann/utils/distconv.hpp" + +#ifdef LBANN_HAS_DISTCONV +namespace distconv { +template +class ChannelwiseSoftmax +{ + using LocaleMPI = tensor::LocaleMPI; + +public: + ChannelwiseSoftmax(Backend& backend) : m_be(backend){}; + + template + int forward(const tensor::Tensor& input_0, + tensor::Tensor& output); + + template + int backward( + const tensor::Tensor& input_0, + const tensor::Tensor& output_grad, + tensor::Tensor& input_grad_0); + +protected: + Backend& m_be; +}; + +extern template class ChannelwiseSoftmax<::distconv::BackendDNNLib, float>; +extern template class ChannelwiseSoftmax<::distconv::BackendDNNLib, double>; +} // namespace distconv + +#endif // LBANN_HAS_DISTCONV +#endif // LBANN_LAYERS_MISC_DISTCONV_CHANNELWISE_SOFTMAX \ No newline at end of file diff --git a/scripts/superbuild/nccl/CMakeLists.txt b/scripts/superbuild/nccl/CMakeLists.txt index 5b8e69e0416..354700e3f4b 100644 --- a/scripts/superbuild/nccl/CMakeLists.txt +++ b/scripts/superbuild/nccl/CMakeLists.txt @@ -48,7 +48,7 @@ endmacro () lbann_sb_init_extern_pkg( NAME NCCL - LANGUAGES C CXX # CUDA <- can't set explicitly; inferred from ${CUDA_HOME} + LANGUAGES C CXX CUDA GITHUB_URL NVIDIA/nccl GIT_TAG "master") @@ -105,8 +105,8 @@ if (LBANN_SB_FWD_NCCL_NVCC_GENCODE) elseif (DEFINED $ENV{NVCC_GENCODE}) set(_nccl_nvcc_gencode_opt "NVCC_GENCODE=$ENV{NVCC_GENCODE}") -elseif (LBANN_NCCL_CUDA_ARCHITECTURES) - set(_cuda_arch ${LBANN_NCCL_CUDA_ARCHITECTURES}) +elseif (LBANN_SB_NCCL_CUDA_ARCHITECTURES) + set(_cuda_arch ${LBANN_SB_NCCL_CUDA_ARCHITECTURES}) set(_nccl_nvcc_gencode_opt "NVCC_GENCODE=-gencode=arch=compute_${_cuda_arch},code=sm_${_cuda_arch}") else () diff --git a/src/layers/distconv_adapter.cpp b/src/layers/distconv_adapter.cpp index b4fe2337820..b91ca6a516c 100644 --- a/src/layers/distconv_adapter.cpp +++ b/src/layers/distconv_adapter.cpp @@ -263,9 +263,10 @@ void distconv_adapter::adjust_parallel_strategy() } } - else if (layer_type == "channel-wise fully-connected" || - layer_type == "matmul") { - if (c != f) { + else if (layer_type == "channel-wise fully-connected" + || layer_type == "matmul" + || layer_type == "channel-wise softmax"){ + if (c != f){ if (layer().get_comm()->am_trainer_master()) { LBANN_WARNING("The number of channel and filter decomposition should " "be the same. 
Setting", diff --git a/src/layers/misc/CMakeLists.txt b/src/layers/misc/CMakeLists.txt index 77b50a02d7b..8bb7c092cfa 100644 --- a/src/layers/misc/CMakeLists.txt +++ b/src/layers/misc/CMakeLists.txt @@ -1,5 +1,5 @@ ################################################################################ -## Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +## Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. ## Produced at the Lawrence Livermore National Laboratory. ## Written by the LBANN Research Team (B. Van Essen, et al.) listed in ## the CONTRIBUTORS file. @@ -38,7 +38,6 @@ set_full_path(THIS_DIR_SOURCES rowwise_weights_norms.cpp uniform_hash.cpp variance.cpp - misc_builders.cpp ) @@ -57,15 +56,21 @@ if (LBANN_HAS_GPU) rowwise_weights_norms.cu uniform_hash.cu variance.cu + channelwise_softmax_kernels.cuh ) if (LBANN_HAS_FFTW) list(APPEND THIS_DIR_CU_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/dft_abs.cu") endif () endif () +if (LBANN_HAS_DISTCONV) + add_subdirectory(distconv) +endif() + # Add the subdirectories add_subdirectory(cereal_registration) + # Propagate the files up the tree set(SOURCES "${SOURCES}" "${THIS_DIR_SOURCES}" PARENT_SCOPE) set(GPU_SOURCES "${GPU_SOURCES}" "${THIS_DIR_CU_SOURCES}" PARENT_SCOPE) diff --git a/src/layers/misc/channelwise_softmax.cpp b/src/layers/misc/channelwise_softmax.cpp index f335a733bc0..29af732628f 100644 --- a/src/layers/misc/channelwise_softmax.cpp +++ b/src/layers/misc/channelwise_softmax.cpp @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. diff --git a/src/layers/misc/channelwise_softmax.cu b/src/layers/misc/channelwise_softmax.cu index 083f4f55a7d..d75fdf6e1e7 100644 --- a/src/layers/misc/channelwise_softmax.cu +++ b/src/layers/misc/channelwise_softmax.cu @@ -1,5 +1,5 @@ //////////////////////////////////////////////////////////////////////////////// -// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. // Produced at the Lawrence Livermore National Laboratory. // Written by the LBANN Research Team (B. Van Essen, et al.) listed in // the CONTRIBUTORS file. @@ -25,527 +25,70 @@ //////////////////////////////////////////////////////////////////////////////// #define LBANN_CHANNELWISE_SOFTMAX_LAYER_INSTANTIATE +#include "channelwise_softmax_kernels.cuh" #include "lbann/layers/misc/channelwise_softmax_impl.hpp" #include "lbann/utils/gpu/helpers.hpp" namespace lbann { -namespace { - -using Size3 = gpu_lib::array; - -/** @brief Max functor */ -template -struct max_op -{ - __device__ __forceinline__ DataType operator()(const T& x1, const T& x2) const - { - return gpu_lib::max(x1, x2); - } -}; - -} // namespace - -// ========================================================= -// Forward prop -// ========================================================= - -namespace { - -/** @brief Max reduction over last dimension of 3D tensor. - * - * Each CUDA block computes the max over a subset of tensor entries - * in @c vals and outputs the result to @c maxvals. This should be - * repeated multiple times to fully reduce the last tensor dimension. 
- * - * Block dimensions: bdimx x 1 x 1 - * - * Grid dimensions: (vals_dims[2] / bdimx) x vals_dims[1] x vals_dims[0] - * - * maxvals: vals_dims[0] x vals_dims[1] x (vals_dims[2] / bdimx) - */ -template -__global__ void fp_max_kernel(Size3 vals_dims, - const TensorDataType* __restrict__ vals_buffer, - Size3 vals_strides, - TensorDataType* __restrict__ maxvals_buffer, - Size3 maxvals_strides) -{ - - // Indices and dimensions - constexpr size_t bdimy = 1; - constexpr size_t bdimz = 1; - const size_t tid = threadIdx.x; - const size_t bidx = blockIdx.x; - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - - for (size_t k = gidz; k < vals_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < vals_dims[1]; j += nthreadsy) { - - // Find largest value for each thread - TensorDataType maxval{-gpu_lib::infinity()}; - for (size_t i = gidx; i < vals_dims[2]; i += nthreadsx) { - const auto& val = - vals_buffer[k * vals_strides[0] + j * vals_strides[1] + - i * vals_strides[2]]; - maxval = gpu_lib::max(maxval, val); - } - - // Find largest value for each block - maxval = gpu_lib::block_reduce>(maxval); - if (tid == 0) { - const auto& pos = (k * maxvals_strides[0] + j * maxvals_strides[1] + - bidx * maxvals_strides[2]); - maxvals_buffer[pos] = maxval; - } - } - } -} - -/** Compute softmax denominator. - * - * denom = sum( exp(x_i-shift) ) - * - * Block dimensions: bdimx x 1 x 1 - * - * Grid dimensions: (input_dims[2] / bdimx) x input_dims[1] x input_dims[0] - * - * shifts and denoms are fully-packed 2D tensors with dimensions of - * input_dims[0] x input_dims[1]. - */ -template -__global__ void fp_denom_kernel(Size3 input_dims, - const TensorDataType* __restrict__ input_buffer, - Size3 input_strides, - const TensorDataType* __restrict__ shifts, - TensorDataType* __restrict__ denoms) -{ - - // Indices and dimensions - constexpr size_t bdimy = 1; - constexpr size_t bdimz = 1; - const size_t tid = threadIdx.x; - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - - for (size_t k = gidz; k < input_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < input_dims[1]; j += nthreadsy) { - - // Compute contribution from each thread - const auto& shift = shifts[j + k * input_dims[1]]; - TensorDataType denom{0.}; - for (size_t i = gidx; i < input_dims[2]; i += nthreadsx) { - const auto& x = - input_buffer[k * input_strides[0] + j * input_strides[1] + - i * input_strides[2]]; - denom += gpu_lib::exp(x - shift); - } - - // Compute contribution from each block - denom = gpu_lib::block_reduce(denom); - if (tid == 0) { - if (gridDim.x > 1) - gpu_lib::atomic_add(&denoms[j + k * input_dims[1]], denom); - else - denoms[j + k * input_dims[1]] = denom; - } - } - } -} - -/** Compute softmax. 
- * - * y_i = exp(x_i-shift) / denom - * - * Block dimensions: bdimx x bdimy x bdimz - * - * Grid dimensions: (input_dims[2] / bdimx) x (input_dims[1] / bdimy) x - * (input_dims[0] / bdimz) - * - * shifts and denoms are fully-packed 2D tensors with dimensions of - * input_dims[0] x input_dims[1]. - */ -template -__global__ void -fp_output_kernel(Size3 input_dims, - const TensorDataType* __restrict__ input_buffer, - Size3 input_strides, - TensorDataType* __restrict__ output_buffer, - Size3 output_strides, - const TensorDataType* __restrict__ shifts, - const TensorDataType* __restrict__ denoms) +template +void channelwise_softmax_layer::fp_compute() { - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - for (size_t k = gidz; k < input_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < input_dims[1]; j += nthreadsy) { - const auto& shift = shifts[j + k * input_dims[1]]; - const auto& denom = denoms[j + k * input_dims[1]]; - for (size_t i = gidx; i < input_dims[2]; i += nthreadsx) { - const auto& x = - input_buffer[k * input_strides[0] + j * input_strides[1] + - i * input_strides[2]]; - auto& y = output_buffer[k * output_strides[0] + j * output_strides[1] + - i * output_strides[2]]; - y = gpu_lib::exp(x - shift) / denom; - } - } +#ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()) { + this->get_distconv_adapter().fp_compute(); + return; } -} - -/** @brief Forward prop */ -template -void fp_impl(size_t num_channels, - size_t channel_size, - size_t channel_stride, - const El::AbstractDistMatrix& input, - El::AbstractDistMatrix& output) -{ +#endif // LBANN_HAS_DISTCONV // Local matrices + const size_t num_channels = this->get_output_dims().front(); + const size_t channel_size = this->get_output_size() / num_channels; using LocalMat = El::Matrix; - const auto& local_input = dynamic_cast(input.LockedMatrix()); - auto& local_output = dynamic_cast(output.Matrix()); - - auto multisync = El::MakeMultiSync(gpu::get_sync_info(local_output), - gpu::get_sync_info(local_input)); - - // Dimensions - const size_t local_mini_batch_size = local_input.Width(); - // const Size3 input_dims{local_mini_batch_size, num_channels, channel_size}; - - // Compute softmax shifts - LocalMat local_shifts; - if (!local_input.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - LocalMat maxvals(grid_dims.x * num_channels, local_mini_batch_size); - hydrogen::gpu::LaunchKernel( - fp_max_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_input.LockedBuffer(), - Size3{static_cast(local_input.LDim()), channel_stride, 1}, - maxvals.Buffer(), - Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); - while (grid_dims.x > 1) { - const size_t prev_dim = grid_dims.x; - grid_dims.x = (prev_dim + block_size - 1) / block_size; - const LocalMat prev_maxvals(std::move(maxvals)); - maxvals.Resize(grid_dims.x * num_channels, local_mini_batch_size); - hydrogen::gpu::LaunchKernel( - fp_max_kernel, - grid_dims, - block_dims, - 0, - multisync, - 
Size3{local_mini_batch_size, num_channels, prev_dim}, - prev_maxvals.LockedBuffer(), - Size3{static_cast(prev_maxvals.LDim()), prev_dim, 1}, - maxvals.Buffer(), - Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); - } - local_shifts = std::move(maxvals); - } - - // Compute softmax denominators - LocalMat local_denoms(num_channels, local_mini_batch_size); - El::Zero(local_denoms); - if (!local_input.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - - // Simple heuristic to switch between atomic softmax denominator vs. - // sequentially accumulating, block-reducing - int sequential_sum_batch = (channel_size + block_size - 1) / block_size; - // The below threshold value has nothing to do with block size - if (sequential_sum_batch < 256) - grid_dims.x = 1; - else - grid_dims.x = sequential_sum_batch; - - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - hydrogen::gpu::LaunchKernel( - fp_denom_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_input.LockedBuffer(), - Size3{static_cast(local_input.LDim()), channel_stride, 1}, - local_shifts.LockedBuffer(), - local_denoms.Buffer()); - } - - // Compute softmax - if (!local_input.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - hydrogen::gpu::LaunchKernel( - fp_output_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_input.LockedBuffer(), - Size3{static_cast(local_input.LDim()), channel_stride, 1}, - local_output.Buffer(), - Size3{static_cast(local_output.LDim()), channel_stride, 1}, - local_shifts.LockedBuffer(), - local_denoms.LockedBuffer()); - } -} - -} // namespace - -template -void channelwise_softmax_layer::fp_compute() -{ - El::Int num_channels, channel_size, channel_stride; - this->get_channel_size_and_stride(channel_size, channel_stride, num_channels); - fp_impl(num_channels, - channel_size, - channel_stride, - this->get_prev_activations(), - this->get_activations()); + const auto& local_input = + dynamic_cast(this->get_prev_activations().LockedMatrix()); + auto& local_output = + dynamic_cast(this->get_activations().Matrix()); + + channelwise_softmax_fp_impl(num_channels, + channel_size, + local_input, + local_output); } // ========================================================= // Backprop // ========================================================= -namespace { - -/** Compute dot product between output and gradient w.r.t. output. - * - * Block dimensions: bdimx x 1 x 1 - * - * Grid dimensions: (output_dims[2] / bdimx) x output_dims[1] x output_dims[0] - * - * y_dot_dy is a fully-packed 2D tensor with dimensions of - * output_dims[0] x output_dims[1]. 
- */ -template -__global__ void -bp_y_dot_dy_kernel(Size3 output_dims, - const TensorDataType* __restrict__ output_buffer, - Size3 output_strides, - const TensorDataType* __restrict__ output_grad_buffer, - Size3 output_grad_strides, - TensorDataType* __restrict__ y_dot_dy) -{ - - // Indices and dimensions - constexpr size_t bdimy = 1; - constexpr size_t bdimz = 1; - const size_t tid = threadIdx.x; - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - - for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { - - // Compute contribution from each thread - TensorDataType _y_dot_dy{0.}; - for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { - const auto& y = - output_buffer[k * output_strides[0] + j * output_strides[1] + - i * output_strides[2]]; - const auto& dy = output_grad_buffer[k * output_grad_strides[0] + - j * output_grad_strides[1] + - i * output_grad_strides[2]]; - _y_dot_dy += y * dy; - } - - // Compute contribution from each block - _y_dot_dy = gpu_lib::block_reduce(_y_dot_dy); - if (tid == 0) { - gpu_lib::atomic_add(&y_dot_dy[j + k * output_dims[1]], _y_dot_dy); - } - } - } -} - -/** Compute gradient w.r.t. input. - * - * dL/dx_i = y_i * ( dL/dy_i - dot(y,dL/dy) ) - * - * Block dimensions: bdimx x bdimy x bdimz - * - * Grid dimensions: (output_dims[2] / bdimx) x (output_dims[1] / bdimy) x - * (output_dims[0] / bdimz) - * - * y_dot_dy is a fully-packed 2D tensor with dimensions of - * output_dims[0] x output_dims[1]. 
- */ -template -__global__ void -bp_input_grad_kernel(Size3 output_dims, - const TensorDataType* __restrict__ output_buffer, - Size3 output_strides, - const TensorDataType* __restrict__ output_grad_buffer, - Size3 output_grad_strides, - TensorDataType* __restrict__ input_grad_buffer, - Size3 input_grad_strides, - const TensorDataType* __restrict__ y_dot_dy) +template +void channelwise_softmax_layer::bp_compute() { - const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; - const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; - const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; - const size_t nthreadsx = blockDim.x * gridDim.x; - const size_t nthreadsy = blockDim.y * gridDim.y; - const size_t nthreadsz = blockDim.z * gridDim.z; - for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { - for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { - const auto& _y_dot_dy = y_dot_dy[j + k * output_dims[1]]; - for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { - const auto& y = - output_buffer[k * output_strides[0] + j * output_strides[1] + - i * output_strides[2]]; - const auto& dy = output_grad_buffer[k * output_grad_strides[0] + - j * output_grad_strides[1] + - i * output_grad_strides[2]]; - auto& dx = input_grad_buffer[k * input_grad_strides[0] + - j * input_grad_strides[1] + - i * input_grad_strides[2]]; - dx = y * (dy - _y_dot_dy); - } - } +#ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()) { + this->get_distconv_adapter().bp_compute(); + return; } -} +#endif // LBANN_HAS_DISTCONV -/** @brief Backprop */ -template -void bp_impl(size_t num_channels, - size_t channel_size, - size_t channel_stride, - const El::AbstractDistMatrix& output, - const El::AbstractDistMatrix& output_grad, - El::AbstractDistMatrix& input_grad) -{ + const size_t num_channels = this->get_output_dims().front(); + const size_t channel_size = this->get_output_size() / num_channels; // Local matrices using LocalMat = El::Matrix; const auto& local_output = - dynamic_cast(output.LockedMatrix()); - const auto& local_output_grad = - dynamic_cast(output_grad.LockedMatrix()); - auto& local_input_grad = dynamic_cast(input_grad.Matrix()); - - // Dimensions - const size_t local_mini_batch_size = local_output.Width(); - - // dot(y,dL/dy) - LocalMat local_y_dot_dy(num_channels, local_mini_batch_size); - El::Zero(local_y_dot_dy); - - auto multisync = El::MakeMultiSync(gpu::get_sync_info(local_y_dot_dy), - gpu::get_sync_info(local_output_grad), - gpu::get_sync_info(local_output), - gpu::get_sync_info(local_input_grad)); - - if (!local_output.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - hydrogen::gpu::LaunchKernel( - bp_y_dot_dy_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_output.LockedBuffer(), - Size3{static_cast(local_output.LDim()), channel_stride, 1}, - local_output_grad.LockedBuffer(), - Size3{static_cast(local_output_grad.LDim()), channel_stride, 1}, - local_y_dot_dy.Buffer()); - } - - // Compute gradient w.r.t. 
input - if (!local_output.IsEmpty()) { - constexpr size_t block_size = 256; - dim3 block_dims, grid_dims; - block_dims.x = block_size; - grid_dims.x = (channel_size + block_size - 1) / block_size; - grid_dims.y = num_channels; - grid_dims.z = local_mini_batch_size; - gpu_lib::clip_grid_dims(grid_dims); - hydrogen::gpu::LaunchKernel( - bp_input_grad_kernel, - grid_dims, - block_dims, - 0, - multisync, - Size3{local_mini_batch_size, num_channels, channel_size}, - local_output.LockedBuffer(), - Size3{static_cast(local_output.LDim()), channel_stride, 1}, - local_output_grad.LockedBuffer(), - Size3{static_cast(local_output_grad.LDim()), channel_stride, 1}, - local_input_grad.Buffer(), - Size3{static_cast(local_input_grad.LDim()), channel_stride, 1}, - local_y_dot_dy.LockedBuffer()); - } -} - -} // namespace - -template -void channelwise_softmax_layer::bp_compute() -{ - El::Int num_channels, channel_size, channel_stride; - this->get_channel_size_and_stride(channel_size, channel_stride, num_channels); - bp_impl(num_channels, - channel_size, - channel_stride, - this->get_activations(), - this->get_prev_error_signals(), - this->get_error_signals()); + dynamic_cast(this->get_activations().LockedMatrix()); + const auto& local_output_grad = dynamic_cast( + this->get_prev_error_signals().LockedMatrix()); + auto& local_input_grad = + dynamic_cast(this->get_error_signals().Matrix()); + + channelwise_softmax_bp_impl(num_channels, + channel_size, + local_output, + local_output_grad, + local_input_grad); } // ========================================================= diff --git a/src/layers/misc/channelwise_softmax_kernels.cuh b/src/layers/misc/channelwise_softmax_kernels.cuh new file mode 100644 index 00000000000..278e43bed6c --- /dev/null +++ b/src/layers/misc/channelwise_softmax_kernels.cuh @@ -0,0 +1,472 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. +// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. +// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. You may +// obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the license. 
+//////////////////////////////////////////////////////////////////////////////// +#ifndef LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS +#define LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS +#include "lbann/utils/gpu/helpers.hpp" +namespace lbann{ +namespace{ +using Size3 = gpu_lib::array; + +/** @brief Max functor */ +template +struct max_op { + __device__ __forceinline__ + DataType operator()(const T& x1, const T& x2) const { + return gpu_lib::max(x1, x2); + } +}; + +// ========================================================= +// Forward prop +// ========================================================= + +/** @brief Max reduction over last dimension of 3D tensor. + * + * Each CUDA block computes the max over a subset of tensor entries + * in @c vals and outputs the result to @c maxvals. This should be + * repeated multiple times to fully reduce the last tensor dimension. + * + * Block dimensions: bdimx x 1 x 1 + * + * Grid dimensions: (vals_dims[2] / bdimx) x vals_dims[1] x vals_dims[0] + * + * maxvals: vals_dims[0] x vals_dims[1] x (vals_dims[2] / bdimx) + */ +template +__global__ void channelwise_softmax_fp_max_kernel( + Size3 vals_dims, + const TensorDataType* __restrict__ vals_buffer, + Size3 vals_strides, + TensorDataType* __restrict__ maxvals_buffer, + Size3 maxvals_strides) { + + // Indices and dimensions + constexpr size_t bdimy = 1; + constexpr size_t bdimz = 1; + const size_t tid = threadIdx.x; + const size_t bidx = blockIdx.x; + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + + for (size_t k = gidz; k < vals_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < vals_dims[1]; j += nthreadsy) { + + // Find largest value for each thread + TensorDataType maxval{-gpu_lib::infinity()}; + for (size_t i = gidx; i < vals_dims[2]; i += nthreadsx) { + const auto& val = vals_buffer[k * vals_strides[0] + + j * vals_strides[1] + + i * vals_strides[2]]; + maxval = gpu_lib::max(maxval, val); + } + + // Find largest value for each block + maxval = gpu_lib::block_reduce>(maxval); + if (tid == 0) { + const auto& pos = (k * maxvals_strides[0] + + j * maxvals_strides[1] + + bidx * maxvals_strides[2]); + maxvals_buffer[pos] = maxval; + } + + } + } + +} + +/** Compute softmax denominator. + * + * denom = sum( exp(x_i-shift) ) + * + * Block dimensions: bdimx x 1 x 1 + * + * Grid dimensions: (input_dims[2] / bdimx) x input_dims[1] x input_dims[0] + * + * shifts and denoms are fully-packed 2D tensors with dimensions of + * input_dims[0] x input_dims[1]. 
+ */ +template +__global__ void channelwise_softmax_fp_denom_kernel( + Size3 input_dims, + const TensorDataType* __restrict__ input_buffer, + Size3 input_strides, + const TensorDataType* __restrict__ shifts, + TensorDataType* __restrict__ denoms) { + + // Indices and dimensions + constexpr size_t bdimy = 1; + constexpr size_t bdimz = 1; + const size_t tid = threadIdx.x; + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + + for (size_t k = gidz; k < input_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < input_dims[1]; j += nthreadsy) { + + // Compute contribution from each thread + const auto& shift = shifts[j + k*input_dims[1]]; + TensorDataType denom{0.}; + for (size_t i = gidx; i < input_dims[2]; i += nthreadsx) { + const auto& x = input_buffer[k * input_strides[0] + + j * input_strides[1] + + i * input_strides[2]]; + denom += gpu_lib::exp(x-shift); + } + + // Compute contribution from each block + denom = gpu_lib::block_reduce(denom); + if (tid == 0) { + gpu_lib::atomic_add(&denoms[j+k*input_dims[1]], denom); + } + + } + } + +} + +/** Compute softmax. + * + * y_i = exp(x_i-shift) / denom + * + * Block dimensions: bdimx x bdimy x bdimz + * + * Grid dimensions: (input_dims[2] / bdimx) x (input_dims[1] / bdimy) x (input_dims[0] / bdimz) + * + * shifts and denoms are fully-packed 2D tensors with dimensions of + * input_dims[0] x input_dims[1]. + */ +template +__global__ void channelwise_softmax_fp_output_kernel( + Size3 input_dims, + const TensorDataType* __restrict__ input_buffer, + Size3 input_strides, + TensorDataType* __restrict__ output_buffer, + Size3 output_strides, + const TensorDataType* __restrict__ shifts, + const TensorDataType* __restrict__ denoms) { + + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + for (size_t k = gidz; k < input_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < input_dims[1]; j += nthreadsy) { + const auto& shift = shifts[j + k*input_dims[1]]; + const auto& denom = denoms[j + k*input_dims[1]]; + for (size_t i = gidx; i < input_dims[2]; i += nthreadsx) { + const auto& x = input_buffer[k * input_strides[0] + + j * input_strides[1] + + i * input_strides[2]]; + auto& y = output_buffer[k * output_strides[0] + + j * output_strides[1] + + i * output_strides[2]]; + y = gpu_lib::exp(x-shift) / denom; + } + } + } + +} + +/** @brief Forward prop */ +template +void channelwise_softmax_fp_impl(size_t num_channels, + size_t channel_size, + const El::Matrix& local_input, + El::Matrix& local_output) { + + // Local matrices + using LocalMat = El::Matrix; + + auto multisync = El::MakeMultiSync(gpu::get_sync_info(local_output), + gpu::get_sync_info(local_input)); + + // Dimensions + const size_t local_mini_batch_size = local_input.Width(); + // const Size3 input_dims{local_mini_batch_size, num_channels, channel_size}; + + // Compute softmax shifts + LocalMat local_shifts; + if (!local_input.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + 
grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + LocalMat maxvals(grid_dims.x * num_channels, local_mini_batch_size); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_fp_max_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_input.LockedBuffer(), + Size3{static_cast(local_input.LDim()), channel_size, 1}, + maxvals.Buffer(), + Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); + while (grid_dims.x > 1) { + const size_t prev_dim = grid_dims.x; + grid_dims.x = (prev_dim + block_size - 1) / block_size; + const LocalMat prev_maxvals(std::move(maxvals)); + maxvals.Resize(grid_dims.x * num_channels, local_mini_batch_size); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_fp_max_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, prev_dim}, + prev_maxvals.LockedBuffer(), + Size3{static_cast(prev_maxvals.LDim()), prev_dim, 1}, + maxvals.Buffer(), + Size3{static_cast(maxvals.LDim()), grid_dims.x, 1}); + } + local_shifts = std::move(maxvals); + } + + // Compute softmax denominators + LocalMat local_denoms(num_channels, local_mini_batch_size); + El::Zero(local_denoms); + if (!local_input.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_fp_denom_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_input.LockedBuffer(), + Size3{static_cast(local_input.LDim()), channel_size, 1}, + local_shifts.LockedBuffer(), + local_denoms.Buffer()); + } + + // Compute softmax + if (!local_input.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_fp_output_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_input.LockedBuffer(), + Size3{static_cast(local_input.LDim()), channel_size, 1}, + local_output.Buffer(), + Size3{static_cast(local_output.LDim()), channel_size, 1}, + local_shifts.LockedBuffer(), + local_denoms.LockedBuffer()); + } + +} + +// ========================================================= +// Backprop +// ========================================================= + +/** Compute dot product between output and gradient w.r.t. output. + * + * Block dimensions: bdimx x 1 x 1 + * + * Grid dimensions: (output_dims[2] / bdimx) x output_dims[1] x output_dims[0] + * + * y_dot_dy is a fully-packed 2D tensor with dimensions of + * output_dims[0] x output_dims[1]. 
+ */ +template +__global__ void channelwise_softmax_bp_y_dot_dy_kernel( + Size3 output_dims, + const TensorDataType* __restrict__ output_buffer, + Size3 output_strides, + const TensorDataType* __restrict__ output_grad_buffer, + Size3 output_grad_strides, + TensorDataType* __restrict__ y_dot_dy) { + + // Indices and dimensions + constexpr size_t bdimy = 1; + constexpr size_t bdimz = 1; + const size_t tid = threadIdx.x; + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + + for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { + + // Compute contribution from each thread + TensorDataType _y_dot_dy{0.}; + for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { + const auto& y = output_buffer[k * output_strides[0] + + j * output_strides[1] + + i * output_strides[2]]; + const auto& dy = output_grad_buffer[k * output_grad_strides[0] + + j * output_grad_strides[1] + + i * output_grad_strides[2]]; + _y_dot_dy += y * dy; + } + + // Compute contribution from each block + _y_dot_dy = gpu_lib::block_reduce(_y_dot_dy); + if (tid == 0) { + gpu_lib::atomic_add(&y_dot_dy[j+k*output_dims[1]], _y_dot_dy); + } + + } + } + +} + +/** Compute gradient w.r.t. input. + * + * dL/dx_i = y_i * ( dL/dy_i - dot(y,dL/dy) ) + * + * Block dimensions: bdimx x bdimy x bdimz + * + * Grid dimensions: (output_dims[2] / bdimx) x (output_dims[1] / bdimy) x (output_dims[0] / bdimz) + * + * y_dot_dy is a fully-packed 2D tensor with dimensions of + * output_dims[0] x output_dims[1]. 
+ */ +template +__global__ void channelwise_softmax_bp_input_grad_kernel( + Size3 output_dims, + const TensorDataType* __restrict__ output_buffer, + Size3 output_strides, + const TensorDataType* __restrict__ output_grad_buffer, + Size3 output_grad_strides, + TensorDataType* __restrict__ input_grad_buffer, + Size3 input_grad_strides, + const TensorDataType* __restrict__ y_dot_dy) { + + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + const size_t nthreadsx = blockDim.x * gridDim.x; + const size_t nthreadsy = blockDim.y * gridDim.y; + const size_t nthreadsz = blockDim.z * gridDim.z; + for (size_t k = gidz; k < output_dims[0]; k += nthreadsz) { + for (size_t j = gidy; j < output_dims[1]; j += nthreadsy) { + const auto& _y_dot_dy = y_dot_dy[j + k*output_dims[1]]; + for (size_t i = gidx; i < output_dims[2]; i += nthreadsx) { + const auto& y = output_buffer[k * output_strides[0] + + j * output_strides[1] + + i * output_strides[2]]; + const auto& dy = output_grad_buffer[k * output_grad_strides[0] + + j * output_grad_strides[1] + + i * output_grad_strides[2]]; + auto& dx = input_grad_buffer[k * input_grad_strides[0] + + j * input_grad_strides[1] + + i * input_grad_strides[2]]; + dx = y * (dy - _y_dot_dy); + } + } + } + +} + + +/** @brief Backprop */ +template +void channelwise_softmax_bp_impl(size_t num_channels, + size_t channel_size, + const El::Matrix& local_output, + const El::Matrix& local_output_grad, + El::Matrix& local_input_grad) { + + // Dimensions + const size_t local_mini_batch_size = local_output.Width(); + using LocalMat = El::Matrix; + // dot(y,dL/dy) + LocalMat local_y_dot_dy(num_channels, local_mini_batch_size); + El::Zero(local_y_dot_dy); + + auto multisync = El::MakeMultiSync(gpu::get_sync_info(local_y_dot_dy), + gpu::get_sync_info(local_output_grad), + gpu::get_sync_info(local_output), + gpu::get_sync_info(local_input_grad)); + + if (!local_output.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_bp_y_dot_dy_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_output.LockedBuffer(), + Size3{static_cast(local_output.LDim()), channel_size, 1}, + local_output_grad.LockedBuffer(), + Size3{static_cast(local_output_grad.LDim()), channel_size, 1}, + local_y_dot_dy.Buffer()); + } + + // Compute gradient w.r.t. 
input + if (!local_output.IsEmpty()) { + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (channel_size + block_size - 1) / block_size; + grid_dims.y = num_channels; + grid_dims.z = local_mini_batch_size; + gpu_lib::clip_grid_dims(grid_dims); + hydrogen::gpu::LaunchKernel( + channelwise_softmax_bp_input_grad_kernel, + grid_dims, block_dims, 0, multisync, + Size3{local_mini_batch_size, num_channels, channel_size}, + local_output.LockedBuffer(), + Size3{static_cast(local_output.LDim()), channel_size, 1}, + local_output_grad.LockedBuffer(), + Size3{static_cast(local_output_grad.LDim()), channel_size, 1}, + local_input_grad.Buffer(), + Size3{static_cast(local_input_grad.LDim()), channel_size, 1}, + local_y_dot_dy.LockedBuffer()); + } + +} +} // namespace anonymous +} // namespace lbann +#endif // LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_KERNELS diff --git a/src/layers/misc/distconv/CMakeLists.txt b/src/layers/misc/distconv/CMakeLists.txt new file mode 100644 index 00000000000..a6c3851eb38 --- /dev/null +++ b/src/layers/misc/distconv/CMakeLists.txt @@ -0,0 +1,31 @@ +################################################################################ +## Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. +## Produced at the Lawrence Livermore National Laboratory. +## Written by the LBANN Research Team (B. Van Essen, et al.) listed in +## the CONTRIBUTORS file. +## +## LLNL-CODE-697807. +## All rights reserved. +## +## This file is part of LBANN: Livermore Big Artificial Neural Network +## Toolkit. For details, see http://software.llnl.gov/LBANN or +## https://github.com/LLNL/LBANN. +## +## Licensed under the Apache License, Version 2.0 (the "Licensee"); you +## may not use this file except in compliance with the License. You may +## obtain a copy of the License at: +## +## http://www.apache.org/licenses/LICENSE-2.0 +## +## Unless required by applicable law or agreed to in writing, software +## distributed under the License is distributed on an "AS IS" BASIS, +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +## implied. See the License for the specific language governing +## permissions and limitations under the license. +################################################################################ +set_full_path(THIS_DIR_CU_SOURCES + distconv_channelwise_softmax.cu + ) + +# Propagate the files up the tree +set(GPU_SOURCES "${GPU_SOURCES}" "${THIS_DIR_CU_SOURCES}" PARENT_SCOPE) diff --git a/src/layers/misc/distconv/distconv_channelwise_softmax.cu b/src/layers/misc/distconv/distconv_channelwise_softmax.cu new file mode 100644 index 00000000000..ba1f25c17ab --- /dev/null +++ b/src/layers/misc/distconv/distconv_channelwise_softmax.cu @@ -0,0 +1,143 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. +// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. +// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. 
You may +// obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the license. +//////////////////////////////////////////////////////////////////////////////// + +#define LBANN_LAYERS_MISC_CHANNELWISE_SOFTMAX_INSTANTIATE +#include "../channelwise_softmax_kernels.cuh" +#include "lbann/base.hpp" +#include "lbann/layers/misc/distconv/distconv_channelwise_softmax.hpp" +#include "lbann/utils/distconv.hpp" +#include "lbann/utils/gpu/helpers.hpp" + +#ifdef LBANN_HAS_DISTCONV +namespace distconv { +template +template +int ChannelwiseSoftmax::forward( + const tensor::Tensor& input_0, + tensor::Tensor& output) +{ + + if (input_0.get_local_size() == 0 || output.get_local_size() == 0) { + util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; + return 1; // no op for empty inputs + } + + const auto& input_0_dims = input_0.get_local_shape(); + + const auto num_channels = input_0_dims[2]; + const auto local_mini_batch_size = input_0_dims[3]; + const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; + const auto mat_stride = num_channels * mat_channel_size; + + // Convert to Hydrogen matrices for kernel launch + + using LocalMat = El::Matrix; + + LocalMat local_input(mat_stride, + local_mini_batch_size, + input_0.get_buffer(), + mat_stride); + + LocalMat local_output(mat_stride, + local_mini_batch_size, + output.get_buffer(), + mat_stride); + + ::lbann::channelwise_softmax_fp_impl(num_channels, + mat_channel_size, + local_input, + local_output); + return 1; +} + +template +template +int ChannelwiseSoftmax::backward( + const tensor::Tensor& output, + const tensor::Tensor& output_grad, + tensor::Tensor& input_grad_0) +{ + if (output.get_local_size() == 0 || output_grad.get_local_size() == 0 || + input_grad_0.get_local_size() == 0) { + util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; + return 1; // no op for empty inputs + } + + const auto& input_0_dims = output.get_local_shape(); + const auto num_channels = input_0_dims[2]; + const auto local_mini_batch_size = input_0_dims[3]; + const auto mat_channel_size = input_0_dims[0] * input_0_dims[1]; + const auto mat_stride = num_channels * mat_channel_size; + + // Convert to Hydrogen matrices for kernel launch + + using LocalMat = El::Matrix; + + LocalMat local_output(mat_stride, + local_mini_batch_size, + output.get_buffer(), + mat_stride); + + LocalMat local_output_grad(mat_stride, + local_mini_batch_size, + output_grad.get_buffer(), + mat_stride); + + LocalMat local_input_grad(mat_stride, + local_mini_batch_size, + input_grad_0.get_buffer(), + mat_stride); + + ::lbann::channelwise_softmax_bp_impl(num_channels, + mat_channel_size, + local_output, + local_output_grad, + local_input_grad); + return 1; +} + +// ========================================================= +// Explicit template instantiation +// ========================================================= + +#define PROTO(T) \ + template class ChannelwiseSoftmax; \ + template int \ + ChannelwiseSoftmax::forward( \ + const tensor::Tensor& \ + input_0, \ + tensor::Tensor& output_0); \ + template int \ + ChannelwiseSoftmax::backward( \ + const tensor::Tensor& \ + input_0, \ + const tensor::Tensor& \ + input_1, \ + 
tensor::Tensor& output_grad); + +#include "lbann/macros/instantiate.hpp" +} // namespace distconv +#endif // LBANN_HAS_DISTCONV
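
For reference, driving the DistConv code path added above from the Python front end follows the same pattern as the unit test: attach a parallel strategy with matching channel and filter groups to the ChannelwiseSoftmax layer and launch with the DistConv environment. The sketch below is illustrative rather than a definitive recipe; the input dimensions (16, 32, 32) and the two channel groups are assumed values, and the DistConv path expects a GPU build, a data-parallel layout, and a 3-D (channel, *, *) input tensor, matching the checks added in channelwise_softmax_impl.hpp and is_distconv_supported().

import lbann
import lbann.contrib.args

# Assumed decomposition: split each sample's channel dimension across two ranks.
groups = 2
strategy = {"channel_groups": groups, "filter_groups": groups}

# 3-D input tensor (channel, height, width), reshaped from the flat sample.
x = lbann.Reshape(lbann.Input(data_field="samples"), dims=(16, 32, 32))

# Channelwise softmax using the DistConv-backed implementation.
y = lbann.ChannelwiseSoftmax(
    x,
    data_layout="data_parallel",
    parallel_strategy=strategy,
    name="channelwise_softmax_distconv",
)

# When generating tests or launching experiments, pass the DistConv runtime
# environment, e.g. environment=lbann.contrib.args.get_distconv_environment(),
# as the unit test does.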