diff --git a/src_sycl/LICENSE.md b/src_sycl/LICENSE.md new file mode 100644 index 0000000..dfdfa9d --- /dev/null +++ b/src_sycl/LICENSE.md @@ -0,0 +1,28 @@ +Modifications Copyright (C) 2023 Intel Corporation + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +SPDX-License-Identifier: BSD-3-Clause diff --git a/src_sycl/README.md b/src_sycl/README.md new file mode 100644 index 0000000..1acc6c5 --- /dev/null +++ b/src_sycl/README.md @@ -0,0 +1,81 @@ +# tsne + +tsne implements the [FIt-SNE algorithm](https://github.com/KlugerLab/FIt-SNE) for various GPU architectures (original CUDA source code is from [here](https://github.com/CannyLab/tsne-cuda)). + +## SYCL version + +- The CUDA code was converted to SYCL using Intel's DPC++ Compatibility Tool (DPCT) available [here](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html). +- The same SYCL code runs on Intel GPUs & CPUs as well as NVIDIA (tested on A100 and H100) and AMD (tested on MI100 and MI250) GPUs. See build instructions below for more details. +- NOTE #1: This version bypasses use of FAISS by running input images through an offline Python version of FAISS and using its output as input to this SYCL version. So this is more suitable for hardware and framework (SYCL, CUDA, HIP) benchmarking. +- NOTE #2: This version also does not use FFT from MKL. Instead it uses a manually implemented FFT. For apples-to-apples comparison, we do have a corresponding (modified) CUDA version available [here](https://github.com/oneapi-src/Velocity-Bench/tree/main/tsne) in [Velocity-Bench](https://github.com/oneapi-src/Velocity-Bench). I am happy to add that CUDA version here, if that will be useful. + +# Current Version: +- Initial release of the workload + +# Build Instructions +Notes +- icpx compiler mentioned below is included in the oneAPI Base Toolkit available [here](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html). +- clang++ compiler mentioned below is available [here](https://github.com/intel/llvm/blob/sycl/sycl/doc/GetStartedGuide.md). + + +For Intel GPU - +First source icpx compiler. Then, + +``` +cd src_sycl/SYCL +mkdir build +cd build +CXX=icpx cmake -DGPU_AOT=pvc .. 
+make -sj +``` +Note: +- To enable AOT compilation, please use the flag `-DGPU_AOT=pvc` for PVC. + +For AMD GPU - +First source clang++ compiler. Then, +``` +cd src_sycl/SYCL +mkdir build +cd build +CXX=clang++ cmake -DUSE_AMDHIP_BACKEND=gfx90a .. +make -sj +``` +Note: +- We use the flag `-DUSE_AMDHIP_BACKEND=gfx90a` for MI250. Use the correct value for your GPU. + +For NVIDIA GPU - +First source clang++ compiler. Then, +``` +cd src_sycl/SYCL +mkdir build +cd build +CXX=clang++ cmake -DUSE_NVIDIA_BACKEND=YES -DUSE_SM=80 .. +make -sj +``` +Note: +- We use the flag `-DUSE_SM=80` for A100 or `-DUSE_SM=90` for H100. + +# Run instructions + +After building, to run the workload, cd into the SYCL/build folder, if not already there. Then + +``` +# PVC 1 tile: +ONEAPI_DEVICE_SELECTOR=level_zero:0.0 ./tsne +``` +``` +# PVC 2 tiles: +ONEAPI_DEVICE_SELECTOR=level_zero:0 ./tsne +``` +``` +# AMD GPU: +ONEAPI_DEVICE_SELECTOR=hip:0 ./tsne +``` +``` +# NVIDIA GPU: +ONEAPI_DEVICE_SELECTOR=cuda:0 ./tsne +``` + +# Output + +Output gives the total time for running the whole workload. diff --git a/src_sycl/SYCL/CMakeLists.txt b/src_sycl/SYCL/CMakeLists.txt new file mode 100644 index 0000000..b2283b5 --- /dev/null +++ b/src_sycl/SYCL/CMakeLists.txt @@ -0,0 +1,131 @@ + # Modifications Copyright (C) 2023 Intel Corporation + # + # Redistribution and use in source and binary forms, with or without modification, + # are permitted provided that the following conditions are met: + # + # 1. Redistributions of source code must retain the above copyright notice, + # this list of conditions and the following disclaimer. + # 2. Redistributions in binary form must reproduce the above copyright notice, + # this list of conditions and the following disclaimer in the documentation + # and/or other materials provided with the distribution. + # 3. 
Neither the name of the copyright holder nor the names of its contributors + # may be used to endorse or promote products derived from this software + # without specific prior written permission. + # + # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS + # BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, + # OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT + # OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE + # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + # + # + # SPDX-License-Identifier: BSD-3-Clause + # + +cmake_minimum_required(VERSION 3.10) + +project(tsne LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +option(ENABLE_KERNEL_PROFILING "Build using kernel profiling" OFF) +option(GPU_AOT "Build AOT for Intel GPU" OFF) +option(USE_NVIDIA_BACKEND "Build for NVIDIA backend" OFF) +option(USE_AMDHIP_BACKEND "Build for AMD HIP backend" OFF) + +if(ENABLE_KERNEL_PROFILING) + message("-- Enabling kernel profiling") + add_compile_options(-DENABLE_KERNEL_PROFILING) +endif() + +set(INTEL_GPU_CXX_FLAGS " -O2 -std=c++17 -fsycl -ffast-math -Wall -Wextra -Wno-unused-parameter -Wno-sign-compare -Wno-unknown-pragmas -Wno-unused-local-typedef ") +set(NVIDIA_GPU_CXX_FLAGS " -O3 -std=c++17 -fsycl -ffast-math -Wall -Wextra -Wno-unused-parameter -Wno-sign-compare -Wno-unknown-pragmas -Wno-unused-local-typedef ") +set(AMD_GPU_CXX_FLAGS " -O3 -std=c++17 -fsycl -ffast-math -Wall -Wextra 
-Wno-unused-parameter -Wno-sign-compare -Wno-unknown-pragmas -Wno-unused-local-typedef ") + +set(USE_DEFAULT_FLAGS ON) +if("${CMAKE_CXX_FLAGS}" STREQUAL "") + message(STATUS "Using DEFAULT compilation flags") +else() + message(STATUS "OVERRIDING DEFAULT compilation flags") + set(USE_DEFAULT_FLAGS OFF) +endif() + +# Backend selection: Intel GPU (AOT), NVIDIA, or AMD HIP +if(GPU_AOT) + message(STATUS "Enabling INTEL backend") + if(USE_DEFAULT_FLAGS) + set(CMAKE_CXX_FLAGS "${INTEL_GPU_CXX_FLAGS}") # Default flags for Intel backend + endif() + if( (${GPU_AOT} STREQUAL "pvc") OR (${GPU_AOT} STREQUAL "PVC") ) + message(STATUS "Enabling Intel GPU AOT compilation for ${GPU_AOT}") + string(APPEND CMAKE_CXX_FLAGS " -fsycl-targets=spir64_gen -Xs \"-device 0x0bd5 -revision_id 0x2f\" ") + else() + message(STATUS "Using custom AOT compilation flag ${GPU_AOT}") + string(APPEND CMAKE_CXX_FLAGS " ${GPU_AOT} ") # User should be aware of advanced AOT compilation flags + endif() +elseif(USE_NVIDIA_BACKEND) + message(STATUS "Enabling NVIDIA backend") + if(USE_DEFAULT_FLAGS) + set(CMAKE_CXX_FLAGS "${NVIDIA_GPU_CXX_FLAGS}") # Default flags for NV backend + endif() + string(APPEND CMAKE_CXX_FLAGS " -fsycl-targets=nvptx64-nvidia-cuda ") # -O3 will be used, even though -O2 was set earlier +elseif(USE_AMDHIP_BACKEND) + message(STATUS "Enabling AMD HIP backend for ${USE_AMDHIP_BACKEND} AMD architecture") + if(USE_DEFAULT_FLAGS) + set(CMAKE_CXX_FLAGS "${AMD_GPU_CXX_FLAGS}") # Default flags for AMD backend (gfx908 for MI100) + endif() + string(APPEND CMAKE_CXX_FLAGS " -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${USE_AMDHIP_BACKEND} ") +endif() + +if(GPU_AOT) + set(MKL_LINK static) + set(MKL_THREADING sequential) + find_package(MKL CONFIG REQUIRED HINTS "$ENV{MKLROOT}/lib/cmake/mkl") +endif() + +# Project Setup +#------------------------------------------------------------------------------- +set(SOURCES + # # Utils + ${CMAKE_SOURCE_DIR}/src/utils/debug_utils.dp.cpp + 
${CMAKE_SOURCE_DIR}/src/utils/cuda_utils.dp.cpp + ${CMAKE_SOURCE_DIR}/src/utils/distance_utils.dp.cpp + ${CMAKE_SOURCE_DIR}/src/utils/math_utils.dp.cpp + ${CMAKE_SOURCE_DIR}/src/utils/matrix_broadcast_utils.dp.cpp + # ${CMAKE_SOURCE_DIR}/src/utils/reduce_utils.dp.cpp + + # # Kernels + ${CMAKE_SOURCE_DIR}/src/kernels/apply_forces.dp.cpp + ${CMAKE_SOURCE_DIR}/src/kernels/attr_forces.dp.cpp + ${CMAKE_SOURCE_DIR}/src/kernels/rep_forces.dp.cpp + ${CMAKE_SOURCE_DIR}/src/kernels/perplexity_search.dp.cpp + ${CMAKE_SOURCE_DIR}/src/kernels/nbodyfft.dp.cpp + + # Method files + ${CMAKE_SOURCE_DIR}/src/fit_tsne.dp.cpp + + ${CMAKE_SOURCE_DIR}/src/exe/main.dp.cpp +) + +include_directories( + ${CMAKE_SOURCE_DIR}/src/ + ${CMAKE_SOURCE_DIR}/src/include + /nfs/pdx/home/mgrabban/oneDPL/include + /nfs/pdx/home/mgrabban/oneTBB/include +) + +add_executable(tsne ${SOURCES}) + +if(GPU_AOT) + target_compile_options(tsne PUBLIC $) + target_include_directories(tsne PUBLIC $) + target_link_libraries(tsne PUBLIC $) +endif() diff --git a/src_sycl/SYCL/src/exe/main.dp.cpp b/src_sycl/SYCL/src/exe/main.dp.cpp new file mode 100644 index 0000000..3fc2fa1 --- /dev/null +++ b/src_sycl/SYCL/src/exe/main.dp.cpp @@ -0,0 +1,139 @@ +/* Modifications Copyright (C) 2023 Intel Corporation + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, + * OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT + * OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE + * OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * + * SPDX-License-Identifier: BSD-3-Clause + */ + +// This file exposes a main file which does most of the testing with command line +// args, so we don't have to re-build to change options. 
+ +// Detailed includes +#include +#include +#include +#include "include/fit_tsne.h" +#include "include/options.h" + +// Option parser +#include "include/cxxopts.hpp" + +// Wall-clock timing helpers. They read and write the locals time_start, time_end, +// time_total and time_total_ declared in main(); TIMER_PRINT reports +// (time_total - time_total_) / 1e3 followed by " s". +#define TIMER_START() time_start = std::chrono::steady_clock::now(); +#define TIMER_END() \ + time_end = std::chrono::steady_clock::now(); \ + time_total = std::chrono::duration(time_end - time_start).count(); +#define TIMER_PRINT(name) std::cout << name <<": " << (time_total - time_total_) / 1e3 << " s\n"; + +// #ifndef DEBUG_TIME +// #define DEBUG_TIME +// #endif + +#define STRINGIFY(X) #X + +// Accessors for the parsed cxxopts `result`. The argument is stringified, so +// FOPT(learning-rate) looks up the option named "learning-rate". +#define FOPT(x) result[STRINGIFY(x)].as() +#define SOPT(x) result[STRINGIFY(x)].as() +#define IOPT(x) result[STRINGIFY(x)].as() +#define BOPT(x) result[STRINGIFY(x)].as() + +// Program entry point: declares the command-line options, parses them, fills a +// tsnecuda::Options from the parsed values, calls tsnecuda::RunTsne, and prints the +// elapsed time. Returns 0 on every path that reaches the end of the function, +// including after a caught std::exception; --help exits early via exit(0). +int main(int argc, char** argv) +{ + std::chrono::steady_clock::time_point time_start; + std::chrono::steady_clock::time_point time_end; + double time_total = 0.0; + double time_total_ = 0.0; + + TIMER_START() + + try { + // Setup command line options + cxxopts::Options options("TSNE-CUDA","Perform T-SNE in an optimized manner."); + options.add_options() + ("l,learning-rate", "Learning Rate", cxxopts::value()->default_value("200")) + ("p,perplexity", "Perplexity", cxxopts::value()->default_value("50.0")) + ("e,early-ex", "Early Exaggeration Factor", cxxopts::value()->default_value("12.0")) + ("s,data", "Which program to run on ", cxxopts::value()->default_value("sim")) + ("k,num-points", "How many simulated points to use", cxxopts::value()->default_value("60000")) + ("u,nearest-neighbors", "How many nearest neighbors should we use", cxxopts::value()->default_value("32")) + ("n,num-steps", "How many steps to take", cxxopts::value()->default_value("1000")) + ("i,viz", "Use interactive visualization", cxxopts::value()->default_value("false")) + ("d,dump", "Dump the output points", cxxopts::value()->default_value("false")) + ("m,magnitude-factor", "Magnitude factor for KNN", cxxopts::value()->default_value("5.0")) + ("t,init", "What kind of 
initialization to use ", cxxopts::value()->default_value("gauss")) + ("f,fname", "File name for loaded data...", cxxopts::value()->default_value("../train-images.idx3-ubyte")) + ("c,connection", "Address for connection to vis server", cxxopts::value()->default_value("tcp://localhost:5556")) + ("q,dim", "Point Dimensions", cxxopts::value()->default_value("50")) + ("j,device", "Device to run on", cxxopts::value()->default_value("0")) + ("h,help", "Print help"); + + // NOTE(review): the data, fname and device options are declared above but are not + // read anywhere in this function. + // Parse command line options + auto result = options.parse(argc, argv); + + if (result.count("help")) + { + std::cout << options.help({""}) << std::endl; + exit(0); + } + + // Only the exact value "unif" selects UNIFORM; every other --init value, + // including the default "gauss", selects GAUSSIAN. + tsnecuda::TSNE_INIT init_type = tsnecuda::TSNE_INIT::UNIFORM; + if (SOPT(init).compare("unif") == 0) { + init_type = tsnecuda::TSNE_INIT::UNIFORM; + } else { + init_type = tsnecuda::TSNE_INIT::GAUSSIAN; + } + + // Report the requested number of points + // NOTE(review): %u is paired with IOPT(num-points); confirm the option's value type is unsigned. + printf("Starting TSNE calculation with %u points.\n", IOPT(num-points)); + + // Construct the options; both iterations and iterations_no_progress take --num-steps + tsnecuda::Options opt(nullptr, IOPT(num-points), IOPT(dim), nullptr); + opt.perplexity = FOPT(perplexity); + opt.learning_rate = FOPT(learning-rate); + opt.early_exaggeration = FOPT(early-ex); + opt.iterations = IOPT(num-steps); + opt.iterations_no_progress = IOPT(num-steps); + opt.magnitude_factor = FOPT(magnitude-factor); + opt.num_neighbors = IOPT(nearest-neighbors); + opt.initialization = init_type; + + if (BOPT(dump)) { + opt.enable_dump("dump_ys.txt", 1); + } + if (BOPT(viz)) { + opt.enable_viz(SOPT(connection)); + } + + // Run t-SNE. The returned value is subtracted from the total in TIMER_PRINT + // (presumably time RunTsne spent loading input data; TODO confirm against RunTsne). + time_total_ = tsnecuda::RunTsne(opt); + std::cout << "\nDone!\n"; + } catch (std::exception const& e) { + // The exception is only reported; execution continues to the timing printout below. + std::cout << "Exception: " << e.what() << "\n"; + } + + TIMER_END() + TIMER_PRINT("tsne - total time for whole calculation") + + return 0; +} diff --git a/src_sycl/SYCL/src/fit_tsne.dp.cpp b/src_sycl/SYCL/src/fit_tsne.dp.cpp new file mode 100644 index 0000000..5f35c80 --- /dev/null +++ b/src_sycl/SYCL/src/fit_tsne.dp.cpp @@ -0,0 +1,770 @@ +/* Modifications 
Copyright (C) 2023 Intel Corporation + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, + * OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT + * OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE + * OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * + * SPDX-License-Identifier: BSD-3-Clause + */ + +/* + Compute t-SNE via Barnes-Hut for NlogN time. 
+*/ + +#include +#include +#include + +#include +#include +#include +#include "include/fit_tsne.h" + +// #ifndef DEBUG_TIME +// #define DEBUG_TIME +// #endif + +#define TIMER_START_() time_start_ = std::chrono::steady_clock::now(); +#define TIMER_END_() \ + time_end_ = std::chrono::steady_clock::now(); \ + time_total_ += std::chrono::duration(time_end_ - time_start_).count(); +#define TIMER_PRINT_(name) std::cout << name <<": " << time_total_ / 1e3 << " s\n"; + +#ifdef DEBUG_TIME +// #define TIME_START() time_start = std::chrono::steady_clock::now(); +// #define TIME_SINCE(x) \ +// std::cout << "\nTime passed: " \ +// << std::chrono::duration_cast( \ +// std::chrono::steady_clock::now() - x).count() \ +// << " ms\n"; + +#define START_IL_TIMER() start = std::chrono::steady_clock::now(); +#define END_IL_TIMER(x) \ + stop = std::chrono::steady_clock::now(); \ + duration = std::chrono::duration_cast(stop - start); \ + x += duration; \ + total_time += duration; +#define PRINT_IL_TIMER(x) std::cout << #x << ": " << ((float)x.count()) / 1000000.0 << "s" << std::endl +#endif + +double tsnecuda::RunTsne(tsnecuda::Options& opt) +{ + std::chrono::steady_clock::time_point time_start_; + std::chrono::steady_clock::time_point time_end_; + double time_total_ = 0.0; + +#ifdef DEBUG_TIME + auto start = std::chrono::steady_clock::now(); + auto stop = std::chrono::steady_clock::now(); + auto duration = std::chrono::duration_cast(stop - start); + + auto total_time = duration; + auto _time_initialization = duration; + auto _time_knn = duration; + auto _time_symmetry = duration; + auto _time_init_low_dim = duration; + auto _time_init_fft = duration; + auto _time_precompute_2d = duration; + auto _time_nbodyfft = duration; + auto _time_compute_charges = duration; + auto _time_other = duration; + auto _time_repl = duration; + auto _time_attr = duration; + auto _time_apply_forces = duration; +#endif + + // auto time_start = std::chrono::steady_clock::now(); + + // Check the validity of the 
options file + if (!opt.validate()) { + std::cout << "E: Invalid options file. Terminating." << std::endl; + return 0.0; + } + + if (opt.verbosity > 0) { + std::cout << "Initializing sycl handles... " << std::flush; + } + + // Construct the handles + // TODO: Move this outside of the timing code, since RAPIDs is cheating by pre-initializing the handle. + // TODO: Allow for multi-stream on the computation, since we can overlap portions of our computation to be quicker. + + // // Setup some return information if we're working on snapshots + // // TODO: Add compile flag to remove snapshotting for timing parity + // int snap_num = 0; + // int snap_interval = 1; + // if (opt.return_style == tsnecuda::RETURN_STYLE::SNAPSHOT) + // { + // snap_interval = opt.iterations / (opt.num_snapshots - 1); + // } + + // Get constants from options + const int num_points = opt.num_points; // number of images + const int high_dim = opt.num_dims; // number of pixels per image + + // TODO: Warn if the number of neighbors is more than the number of points + const int num_neighbors = (opt.num_neighbors < num_points) ? 
opt.num_neighbors : num_points; + const float perplexity = opt.perplexity; + const float perplexity_search_epsilon = opt.perplexity_search_epsilon; + const float eta = opt.learning_rate; + float momentum = opt.pre_exaggeration_momentum; + float attr_exaggeration = opt.early_exaggeration; + float normalization = 0.0f; + + // Allocate host memory + // TODO: Pre-determine GPU/CPU memory requirements, since we will know them ahead of time, and can estimate + // if you're going to run out of GPU memory + // TODO: Investigate what it takes to use unified memory + Async fetch and execution + long* knn_indices = new long[ num_points * num_neighbors]; + float* knn_distances = new float[num_points * num_neighbors]; + memset(knn_distances, 0.0f, num_points * num_neighbors * sizeof(float)); + + if (opt.verbosity > 0) { + std::cout << "done.\nKNN Load...\n" << std::flush; + } + + TIMER_START_() + // Compute approximate K Nearest Neighbors and squared distances + // TODO: See if we can gain some time here by updating FAISS, and building better indicies + // TODO: Add suport for arbitrary metrics on GPU (Introduced by recent FAISS computation) + // TODO: Expose Multi-GPU computation (+ Add streaming memory support for GPU optimization) + std::string data_folder = "../../data/mnist_faissed/"; + // std::string data_folder = "../../data/cifar10_faissed/"; + tsnecuda::utils::KNearestNeighbors( + std::move(data_folder), // folder containing input files + knn_indices, // *** output indices *** + knn_distances, // *** output distances *** + high_dim, // number of pixels per image = 784 + num_points, // number of images + num_neighbors); + TIMER_END_() + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + sycl::device dts = sycl::device(sycl::default_selector_v); + + sycl::queue qts(dts); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_initialization); +#endif + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + auto knn_indices_device = sycl::malloc_device(num_points * num_neighbors, qts); + 
auto pij_indices_device = sycl::malloc_device(num_points * num_neighbors, qts); + qts.memcpy(knn_indices_device, knn_indices, num_points * num_neighbors * sizeof(long)); + qts.wait_and_throw(); + + tsnecuda::utils::PostprocessNeighborIndices( // pij_indices_device[i] = (int)knn_indices_device + pij_indices_device, // output + knn_indices_device, // input + num_points, + num_neighbors, + qts); + + // Max-norm the distances to avoid exponentiating by large numbers + auto knn_distances_device = sycl::malloc_device(num_points * num_neighbors, qts); + qts.memcpy(knn_distances_device, knn_distances, num_points * num_neighbors * sizeof(float)); + qts.wait_and_throw(); + tsnecuda::utils::MaxNormalizeDeviceVector( // divide by max abs value + knn_distances_device, // input and output + num_points * num_neighbors, + qts); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_knn); +#endif + + if (opt.verbosity > 0) { + std::cout << "done.\nComputing Pij matrix... " << std::endl; + } + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + // Search Perplexity + auto pij_nonsymmetric_device = sycl::malloc_device(num_points * num_neighbors, qts); + tsnecuda::SearchPerplexity( + pij_nonsymmetric_device, // output + knn_distances_device, // input + perplexity, + perplexity_search_epsilon, + num_points, + num_neighbors, + qts); + qts.wait_and_throw(); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_symmetry); +#endif + + // Clean up memory + sycl::free(knn_indices_device, qts); + sycl::free(knn_distances_device, qts); + delete[] knn_indices; + delete[] knn_distances; + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + // Symmetrize the pij matrix + auto pij_symmetric_device = sycl::malloc_device(num_points * num_neighbors, qts); + tsnecuda::utils::SymmetrizeMatrixV2( + pij_symmetric_device, // output + pij_nonsymmetric_device, // input + pij_indices_device, // input + num_points, + num_neighbors, + qts); + qts.wait_and_throw(); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_symmetry); +#endif + + // Clean 
up memory + sycl::free(pij_nonsymmetric_device, qts); + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + // Declare memory + auto attractive_forces_device = sycl::malloc_device(opt.num_points * 2, qts); + auto repulsive_forces_device = sycl::malloc_device(opt.num_points * 2, qts); + auto gains_device = sycl::malloc_device(opt.num_points * 2, qts); + auto old_forces_device = sycl::malloc_device(opt.num_points * 2, qts); + auto ones_device = sycl::malloc_device(opt.num_points * 2, qts); + auto normalization_vec_device = sycl::malloc_device(opt.num_points, qts); + auto pij_workspace_device = sycl::malloc_device(num_points * num_neighbors * 2, qts); + + qts.fill(gains_device, 1.0f, opt.num_points * 2); + qts.fill(old_forces_device, 0.0f, opt.num_points * 2); + qts.fill(ones_device, 1.0f, opt.num_points * 2); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_symmetry); +#endif + + if (opt.verbosity > 0) { + std::cout << "done.\nInitializing low dim points... " << std::flush; + } + + auto points_host = sycl::malloc_host(num_points * 2, qts); + + TIMER_START_() + if (opt.initialization == tsnecuda::TSNE_INIT::GAUSSIAN) { // Random gaussian initialization + std::ifstream points_file; + points_file.open("../../data/points.txt"); + if (!points_file) std::cerr << "Can't open points.txt!"; + if (points_file.is_open()) { + int i = 0; + while (!points_file.eof()) { + points_file >> points_host[i++]; + } + points_file.close(); + i--; + if (i != num_points * 2) { + std::cout << "Number of data points read: " << i << std::endl; + std::cout << "That is incorrect\n"; + exit(1); + } + } else { + std::cout << "Can't read points file\n"; + exit(1); + } + } else { // Invalid initialization + std::cerr << "E: Invalid initialization type specified." 
<< std::endl; + exit(1); + } + TIMER_END_() + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + // Initialize Low-Dim Points + auto points_device = sycl::malloc_device(num_points * 2, qts); + qts.memcpy(points_device, points_host, num_points * 2 * sizeof(float)); + qts.wait(); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_init_low_dim); +#endif + + sycl::free(points_host, qts); + + if (opt.verbosity > 0) { + std::cout << "done.\nInitializing SYCL memory... " << std::flush; + } + + // FIT-TNSE Parameters + int n_terms = 4; + int n_interp_points = 3; + int n_boxes_per_dim = 125; + + // FFTW works faster on numbers that can be written as 2^a 3^b 5^c 7^d + // 11^e 13^f, where e+f is either 0 or 1, and the other exponents are + // arbitrary + int allowed_n_boxes_per_dim[21] = {25, 36, 50, 55, 60, 65, 70, 75, 80, 85, 90, 96, 100, 110, 120, 130, 140, 150, 175, 200, 1125}; + if (n_boxes_per_dim < allowed_n_boxes_per_dim[20]) { //Round up to nearest grid point + int chosen_i; + for (chosen_i = 0; allowed_n_boxes_per_dim[chosen_i] < n_boxes_per_dim; chosen_i++) + ; + n_boxes_per_dim = allowed_n_boxes_per_dim[chosen_i]; + } + + int n_total_boxes = n_boxes_per_dim * n_boxes_per_dim; + int total_interp_points = n_interp_points * n_interp_points * n_total_boxes; + int n_fft_coeffs_half = n_interp_points * n_boxes_per_dim; + int n_fft_coeffs = n_interp_points * n_boxes_per_dim * 2; + // int n_interp_points_1d = n_interp_points * n_boxes_per_dim; + int N = num_points; + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + // FIT-TSNE Device Vectors + auto point_box_idx_device = sycl::malloc_device(N, qts); + auto x_in_box_device = sycl::malloc_device(N, qts); + auto y_in_box_device = sycl::malloc_device(N, qts); + auto y_tilde_values = sycl::malloc_device(total_interp_points * n_terms, qts); + auto x_interpolated_values_device = sycl::malloc_device(N * n_interp_points, qts); + auto y_interpolated_values_device = sycl::malloc_device(N * n_interp_points, qts); + auto potentialsQij_device 
= sycl::malloc_device(N * n_terms, qts); + // auto all_interpolated_values_device = sycl::malloc_device(n_terms * n_interp_points * n_interp_points * N, qts); + // auto output_values = sycl::malloc_device(n_terms * n_interp_points * n_interp_points * N, qts); + // auto all_interpolated_indices = sycl::malloc_device(n_terms * n_interp_points * n_interp_points * N, qts); + // auto output_indices = sycl::malloc_device(n_terms * n_interp_points * n_interp_points * N, qts); + auto w_coefficients_device = sycl::malloc_device(total_interp_points * n_terms, qts); + auto chargesQij_device = sycl::malloc_device(N * n_terms, qts); + auto box_lower_bounds_device = sycl::malloc_device(2 * n_total_boxes, qts); + auto box_upper_bounds_device = sycl::malloc_device(2 * n_total_boxes, qts); + + auto kernel_tilde_device = sycl::malloc_device( n_fft_coeffs * n_fft_coeffs, qts); + auto fft_kernel_tilde_device = sycl::malloc_device>(n_fft_coeffs * n_fft_coeffs, qts); + + auto fft_scratchpad_device = sycl::malloc_device>(n_fft_coeffs * n_fft_coeffs * n_terms, qts); // added + + auto fft_input = sycl::malloc_device( n_fft_coeffs * n_fft_coeffs * n_terms, qts); + auto fft_w_coefficients = sycl::malloc_device>(n_fft_coeffs * (n_fft_coeffs/2+1) * n_terms, qts); + auto fft_output = sycl::malloc_device( n_fft_coeffs * n_fft_coeffs * n_terms, qts); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_init_fft); +#endif + + // Easier to compute denominator on CPU, so we should just calculate y_tilde_spacing on CPU also + float h = 1 / (float)n_interp_points; + float y_tilde_spacings[n_interp_points]; + y_tilde_spacings[0] = h / 2; + for (int i = 1; i < n_interp_points; i++) { + y_tilde_spacings[i] = y_tilde_spacings[i - 1] + h; + } + float denominator[n_interp_points]; + for (int i = 0; i < n_interp_points; i++) { + denominator[i] = 1; + for (int j = 0; j < n_interp_points; j++) { + if (i != j) { + denominator[i] *= y_tilde_spacings[i] - y_tilde_spacings[j]; + } + } + } + +#ifdef DEBUG_TIME + 
START_IL_TIMER(); +#endif + + auto y_tilde_spacings_device = sycl::malloc_device(n_interp_points, qts); + auto denominator_device = sycl::malloc_device(n_interp_points, qts); + qts.memcpy(y_tilde_spacings_device, y_tilde_spacings, n_interp_points * sizeof(float)); + qts.memcpy(denominator_device, denominator , n_interp_points * sizeof(float)); + qts.fill(fft_input, 0.0f, n_fft_coeffs * n_fft_coeffs * n_terms); + + auto policy = oneapi::dpl::execution::make_device_policy(qts); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_init_fft); +#endif + + if (opt.verbosity > 0) { + std::cout << "done." << std::endl; + } + + // int fft_dimensions[2] = {n_fft_coeffs, n_fft_coeffs}; // {780, 780} + + // std::int64_t fwd_strides1[3] = {0, n_fft_coeffs, 1}; // {0, 780, 1} -> 0 + 780*i + j + // std::int64_t fwd_strides2[3] = {0, (n_fft_coeffs/2+1)*2, 1}; // {0, 780, 1} -> 0 + 780*i + j + // std::int64_t bwd_strides[3] = {0, (n_fft_coeffs/2+1), 1}; // {0, 391, 1} -> 0 + 391*i + j + // std::int64_t fwd_distances1 = n_fft_coeffs* n_fft_coeffs; + // std::int64_t fwd_distances2 = n_fft_coeffs*(n_fft_coeffs/2+1)*2; + // std::int64_t bwd_distances = n_fft_coeffs*(n_fft_coeffs/2+1) ; + + // std::cout << "Setting up dft plans...\n"; + // // *** TIMED SEPARATELY. 
NOT ADDED TO PERF TIME *** + // TIME_START(); + // std::shared_ptr plan_tilde; + // plan_tilde = std::make_shared(std::vector{n_fft_coeffs, n_fft_coeffs}); + // plan_tilde->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); + // plan_tilde->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, fwd_strides2); + // plan_tilde->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, bwd_strides); + // // plan_tilde->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, fwd_distances2); + // // plan_tilde->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, bwd_distances); + // plan_tilde->commit(qts); + // TIME_SINCE(time_start); + + // TIME_START(); + // std::shared_ptr plan_dft; + // plan_dft = std::make_shared(std::vector{n_fft_coeffs, n_fft_coeffs}); + // plan_dft->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); + // plan_dft->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, fwd_strides1); + // plan_dft->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, bwd_strides); + // plan_dft->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, fwd_distances1); + // plan_dft->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, bwd_distances); + // plan_dft->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, n_terms); + // plan_dft->commit(qts); + // TIME_SINCE(time_start); + + // TIME_START(); + // std::shared_ptr plan_idft; + // plan_idft = std::make_shared(std::vector{n_fft_coeffs, n_fft_coeffs}); + // plan_idft->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); + // plan_idft->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, bwd_strides); + // plan_idft->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, fwd_strides2); + // plan_idft->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, fwd_distances2); + // plan_idft->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, 
bwd_distances); + // plan_idft->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, n_terms); + // plan_idft->commit(qts); + // // *** TIMED SEPARATELY. NOT ADDED TO PERF TIME *** + // TIME_SINCE(time_start); + // std::cout << "done.\n"; + + // std::shared_ptr plan_tilde; + // std::shared_ptr plan_dft; + // std::shared_ptr plan_idft; + double duration_fft1 = 0.0, duration_fft2 = 0.0; + + // Support for infinite iteration + for (size_t step = 0; step != opt.iterations; step++) { + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + // TODO: We might be able to write a kernel which does this more efficiently. It probably doesn't require much + // TODO: but it could be done. + qts.fill(w_coefficients_device, 0.0f, total_interp_points * n_terms); // needs to be initialized + qts.fill(potentialsQij_device, 0.0f, N * n_terms); // needs to be initialized + qts.wait(); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_other); +#endif + + // Setup learning rate schedule + if (step == opt.force_magnify_iters) { + momentum = opt.post_exaggeration_momentum; + attr_exaggeration = 1.0f; + } + + // Prepare the terms that we'll use to compute the sum i.e. 
the repulsive forces +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + tsnecuda::ComputeChargesQij( + chargesQij_device, // output + points_device, // input + num_points, + n_terms, + qts); +#ifdef DEBUG_TIME + END_IL_TIMER(_time_compute_charges); +#endif + + // Compute Minimax elements +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + float min_coord = std::reduce( + policy, + points_device, + points_device + num_points * 2, + 0.0f, + oneapi::dpl::minimum()); + float max_coord = std::reduce( + policy, + points_device, + points_device + num_points * 2, + 0.0f, + oneapi::dpl::maximum()); +#ifdef DEBUG_TIME + END_IL_TIMER(_time_precompute_2d); +#endif + + float box_width = (max_coord - min_coord) / (float)n_boxes_per_dim; + if (step < 30) { + std::cout << "step: " << step << " min_coord: " << min_coord << " max_coord: " << max_coord << std::endl; + std::cout << " box_width: " << box_width << std::endl; + } + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + // Compute the number of boxes in a single dimension and the total number of boxes in 2d + tsnecuda::PrecomputeFFT2D( + // plan_tilde, + max_coord, // input + min_coord, // input + max_coord, // input + min_coord, // input + n_boxes_per_dim, // 130 + n_interp_points, // 3 + box_lower_bounds_device, // output: 2 * n_total_boxes size buffer [where n_total_boxes = 130 x 130] + box_upper_bounds_device, // output: 2 * n_total_boxes size buffer [where n_total_boxes = 130 x 130] + kernel_tilde_device, // output?: n_fft_coeffs * n_fft_coeffs size buffer [n_fft_coeffs = 2 x 3 x 130] + fft_kernel_tilde_device, // output: n_fft_coeffs * n_fft_coeffs size buffer + fft_scratchpad_device, + qts, duration_fft1); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_precompute_2d); +#endif + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + tsnecuda::NbodyFFT2D( + // plan_dft, + // plan_idft, + fft_kernel_tilde_device, // input + fft_w_coefficients, // intermediate value + N, + n_terms, + n_boxes_per_dim, + n_interp_points, + n_total_boxes, 
+ total_interp_points, + min_coord, + box_width, + n_fft_coeffs_half, + n_fft_coeffs, + fft_input, // intermediate value + fft_output, // intermediate value + point_box_idx_device, // intermediate value + x_in_box_device, // intermediate value + y_in_box_device, // intermediate value + points_device, // input + box_lower_bounds_device, // input + y_tilde_spacings_device, // input (calculated outside the loop) + denominator_device, // input (calculated outside the loop) + y_tilde_values, // intermediate value + // all_interpolated_values_device, + // output_values, + // all_interpolated_indices, + // output_indices, + w_coefficients_device, + chargesQij_device, // input + x_interpolated_values_device, // intermediate value + y_interpolated_values_device, // intermediate value + potentialsQij_device, // intermediate value + fft_scratchpad_device, + qts, duration_fft2); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_nbodyfft); +#endif + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + // TODO: We can overlap the computation of the attractive and repulsive forces, this requires changing the + // TODO: default streams of the code in both of these methods + // TODO: See: https://stackoverflow.com/questions/24368197/getting-cuda-thrust-to-use-a-cuda-stream-of-your-choice + // Make the negative term, or F_rep in the equation 3 of the paper + + // Calculate Repulsive Forces + normalization = tsnecuda::ComputeRepulsiveForces( + repulsive_forces_device, // num_points * 2 (output: uninitialized) + normalization_vec_device, // num_points (output: uninitialized) + points_device, // num_points * 2 (input: initially randomly generated) + potentialsQij_device, // N * n_terms = num_points * 4 (input) + num_points, + n_terms, + qts); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_repl); +#endif + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + // Calculate Attractive Forces + tsnecuda::ComputeAttractiveForcesV3( + attractive_forces_device, // num_points * 2 (output: uninitialized) + 
pij_symmetric_device, // num_points * num_neighbors (input: calculated using SymmetrizeMatrixV2) + pij_indices_device, // num_points * num_neighbors (input: calculated using PostprocessNeighborIndices) + pij_workspace_device, // num_points * num_neighbors * 2 (output) + points_device, // num_points * 2 (input: initially randomly generated) + ones_device, // num_points * 2 (input: all 1.0f) + num_points, + num_neighbors, + qts); + + qts.wait(); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_attr); +#endif + +#ifdef DEBUG_TIME + START_IL_TIMER(); +#endif + + // TODO: Add stream synchronization here. + + // Apply Forces + tsnecuda::ApplyForces( + points_device, // num_points * 2 (input/output: initially randomly generated) + attractive_forces_device, // num_points * 2 (input/output: calculated using ComputeAttractiveForcesV3) + repulsive_forces_device, // num_points * 2 (input/output: calculated using ComputeRepulsiveForces) + gains_device, // num_points * 2 (input/output: all 1.0f) + old_forces_device, // num_points * 2 (input/output: all 0.0f) + eta, // scalar (learning rate) + normalization, // scalar (return value of ComputeRepulsiveForces) + momentum, + attr_exaggeration, + num_points, + qts); + + // Compute the gradient norm + float grad_norm = tsnecuda::utils::L2NormDeviceVector( + old_forces_device, + opt.num_points * 2, + qts); + +#ifdef DEBUG_TIME + END_IL_TIMER(_time_apply_forces); +#endif + + if (grad_norm < opt.min_gradient_norm) { + if (opt.verbosity >= 1) { + std::cout << "Reached minimum gradient norm: " << grad_norm << std::endl; + } + break; + } + + if (opt.verbosity >= 1 && step % opt.print_interval == 0) { + std::cout << "[Step " << step << "] Avg. 
Gradient Norm: " << grad_norm << std::endl; + } + } // End for loop + // std::cout << "DFT2D1gpu : duration_fft2: " << duration_fft2 << " ms" << std::endl; +#ifdef DEBUG_TIME + if (opt.verbosity > 0) { + PRINT_IL_TIMER(_time_initialization); + PRINT_IL_TIMER(_time_knn); + PRINT_IL_TIMER(_time_symmetry); + PRINT_IL_TIMER(_time_init_low_dim); + PRINT_IL_TIMER(_time_init_fft); + PRINT_IL_TIMER(_time_compute_charges); + PRINT_IL_TIMER(_time_precompute_2d); + PRINT_IL_TIMER(_time_nbodyfft); + PRINT_IL_TIMER(_time_repl); + PRINT_IL_TIMER(_time_attr); + PRINT_IL_TIMER(_time_apply_forces); + PRINT_IL_TIMER(_time_other); + // PRINT_IL_TIMER(total_time); + } +#endif + + // Write output - not timed + if (opt.get_dump_points()) { + auto host_ys = sycl::malloc_host(num_points * 2, qts); + qts.memcpy(host_ys, points_device, num_points * 2 * sizeof(float)).wait(); + + TIMER_START_() + std::ofstream dump_file; + dump_file.open(opt.get_dump_file()); + dump_file << num_points << " " << 2 << std::endl; + + for (int i = 0; i < opt.num_points; i++) { + dump_file << host_ys[i] << " " << host_ys[i + num_points] << std::endl; + } + dump_file.close(); + TIMER_END_() + + sycl::free(host_ys, qts); + } + + // Return some final values + opt.trained = true; + opt.trained_norm = normalization; + + sycl::free(pij_indices_device, qts); + sycl::free(pij_symmetric_device, qts); + sycl::free(attractive_forces_device, qts); + sycl::free(repulsive_forces_device, qts); + sycl::free(gains_device, qts); + sycl::free(old_forces_device, qts); + sycl::free(ones_device, qts); + sycl::free(normalization_vec_device, qts); + sycl::free(pij_workspace_device, qts); + sycl::free(points_device, qts); + sycl::free(point_box_idx_device, qts); + sycl::free(x_in_box_device, qts); + sycl::free(y_in_box_device, qts); + sycl::free(y_tilde_values, qts); + sycl::free(x_interpolated_values_device, qts); + sycl::free(y_interpolated_values_device, qts); + sycl::free(potentialsQij_device, qts); + // 
sycl::free(all_interpolated_values_device, qts); + // sycl::free(output_values, qts); + // sycl::free(all_interpolated_indices, qts); + // sycl::free(output_indices, qts); + sycl::free(w_coefficients_device, qts); + sycl::free(chargesQij_device, qts); + sycl::free(box_lower_bounds_device, qts); + sycl::free(box_upper_bounds_device, qts); + sycl::free(kernel_tilde_device, qts); + sycl::free(fft_input, qts); + sycl::free(fft_output, qts); + sycl::free(fft_w_coefficients, qts); + sycl::free(fft_kernel_tilde_device, qts); + sycl::free(y_tilde_spacings_device, qts); + sycl::free(denominator_device, qts); + + TIMER_PRINT_("time to subtract from total") + return time_total_; +} diff --git a/src_sycl/SYCL/src/include/common.h b/src_sycl/SYCL/src/include/common.h new file mode 100644 index 0000000..5983807 --- /dev/null +++ b/src_sycl/SYCL/src/include/common.h @@ -0,0 +1,71 @@ +/* Modifications Copyright (C) 2023 Intel Corporation + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * 3. Neither the name of the copyright holder nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, + * OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT + * OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE + * OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * + * SPDX-License-Identifier: BSD-3-Clause + */ + +/** + * @brief Common includes for the Cuda-TSNE project + * + * @file common.h + * @author David Chan + * @date 2018-04-04 + */ + + #ifndef COMMON_H + #define COMMON_H + + // SYCL Includes +#include +// #include + +// typedef oneapi::mkl::dft::descriptor descriptor_t; + +// Thrust includes + +// C Library includes +#include +#include +#include +#include +#include +#include +#include +#include + +// C++ Library includes +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#endif diff --git a/src_sycl/SYCL/src/include/cxxopts.hpp b/src_sycl/SYCL/src/include/cxxopts.hpp new file mode 100644 index 0000000..966833c --- /dev/null +++ b/src_sycl/SYCL/src/include/cxxopts.hpp @@ -0,0 +1,2710 @@ +/* + +Copyright (c) 2014, 2015, 2016, 2017 Jarryd Beck + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +*/ + +#ifndef CXXOPTS_HPP_INCLUDED +#define CXXOPTS_HPP_INCLUDED + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__GNUC__) && !defined(__clang__) +# if (__GNUC__ * 10 + __GNUC_MINOR__) < 49 +# define CXXOPTS_NO_REGEX true +# endif +#endif + +#ifndef CXXOPTS_NO_REGEX +# include +#endif // CXXOPTS_NO_REGEX + +// Nonstandard before C++17, which is coincidentally what we also need for +#ifdef __has_include +# if __has_include() +#include +#include +# ifdef __cpp_lib_optional +# define CXXOPTS_HAS_OPTIONAL +# endif +# endif +#endif + +#if __cplusplus >= 201603L +#define CXXOPTS_NODISCARD [[nodiscard]] +#else +#define CXXOPTS_NODISCARD +#endif + +#ifndef CXXOPTS_VECTOR_DELIMITER +#define CXXOPTS_VECTOR_DELIMITER ',' +#endif + +#define CXXOPTS__VERSION_MAJOR 3 +#define CXXOPTS__VERSION_MINOR 0 +#define CXXOPTS__VERSION_PATCH 0 + +#if (__GNUC__ < 10 || (__GNUC__ == 10 && __GNUC_MINOR__ < 1)) && __GNUC__ >= 6 + #define CXXOPTS_NULL_DEREF_IGNORE +#endif + +namespace cxxopts +{ + static constexpr struct { + uint8_t major, minor, patch; + } version = { + CXXOPTS__VERSION_MAJOR, + CXXOPTS__VERSION_MINOR, + CXXOPTS__VERSION_PATCH + }; +} // namespace cxxopts + +//when we ask cxxopts to use Unicode, help strings are processed using ICU, +//which results in the correct lengths being computed for strings when they +//are formatted for the help output +//it is necessary to make sure that can be found 
by the +//compiler, and that icu-uc is linked in to the binary. + +#ifdef CXXOPTS_USE_UNICODE +#include + +namespace cxxopts +{ + using String = icu::UnicodeString; + + inline + String + toLocalString(std::string s) + { + return icu::UnicodeString::fromUTF8(std::move(s)); + } + +#if defined(__GNUC__) +// GNU GCC with -Weffc++ will issue a warning regarding the upcoming class, we want to silence it: +// warning: base class 'class std::enable_shared_from_this' has accessible non-virtual destructor +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnon-virtual-dtor" +#pragma GCC diagnostic ignored "-Weffc++" +// This will be ignored under other compilers like LLVM clang. +#endif + class UnicodeStringIterator : public + std::iterator + { + public: + + UnicodeStringIterator(const icu::UnicodeString* string, int32_t pos) + : s(string) + , i(pos) + { + } + + value_type + operator*() const + { + return s->char32At(i); + } + + bool + operator==(const UnicodeStringIterator& rhs) const + { + return s == rhs.s && i == rhs.i; + } + + bool + operator!=(const UnicodeStringIterator& rhs) const + { + return !(*this == rhs); + } + + UnicodeStringIterator& + operator++() + { + ++i; + return *this; + } + + UnicodeStringIterator + operator+(int32_t v) + { + return UnicodeStringIterator(s, i + v); + } + + private: + const icu::UnicodeString* s; + int32_t i; + }; +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + + inline + String& + stringAppend(String&s, String a) + { + return s.append(std::move(a)); + } + + inline + String& + stringAppend(String& s, size_t n, UChar32 c) + { + for (size_t i = 0; i != n; ++i) + { + s.append(c); + } + + return s; + } + + template + String& + stringAppend(String& s, Iterator begin, Iterator end) + { + while (begin != end) + { + s.append(*begin); + ++begin; + } + + return s; + } + + inline + size_t + stringLength(const String& s) + { + return s.length(); + } + + inline + std::string + toUTF8String(const String& s) + { + std::string 
result; + s.toUTF8String(result); + + return result; + } + + inline + bool + empty(const String& s) + { + return s.isEmpty(); + } +} + +namespace std +{ + inline + cxxopts::UnicodeStringIterator + begin(const icu::UnicodeString& s) + { + return cxxopts::UnicodeStringIterator(&s, 0); + } + + inline + cxxopts::UnicodeStringIterator + end(const icu::UnicodeString& s) + { + return cxxopts::UnicodeStringIterator(&s, s.length()); + } +} + +//ifdef CXXOPTS_USE_UNICODE +#else + +namespace cxxopts +{ + using String = std::string; + + template + T + toLocalString(T&& t) + { + return std::forward(t); + } + + inline + size_t + stringLength(const String& s) + { + return s.length(); + } + + inline + String& + stringAppend(String&s, const String& a) + { + return s.append(a); + } + + inline + String& + stringAppend(String& s, size_t n, char c) + { + return s.append(n, c); + } + + template + String& + stringAppend(String& s, Iterator begin, Iterator end) + { + return s.append(begin, end); + } + + template + std::string + toUTF8String(T&& t) + { + return std::forward(t); + } + + inline + bool + empty(const std::string& s) + { + return s.empty(); + } +} // namespace cxxopts + +//ifdef CXXOPTS_USE_UNICODE +#endif + +namespace cxxopts +{ + namespace + { +#ifdef _WIN32 + const std::string LQUOTE("\'"); + const std::string RQUOTE("\'"); +#else + const std::string LQUOTE("‘"); + const std::string RQUOTE("’"); +#endif + } // namespace + +#if defined(__GNUC__) +// GNU GCC with -Weffc++ will issue a warning regarding the upcoming class, we want to silence it: +// warning: base class 'class std::enable_shared_from_this' has accessible non-virtual destructor +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnon-virtual-dtor" +#pragma GCC diagnostic ignored "-Weffc++" +// This will be ignored under other compilers like LLVM clang. 
+#endif + class Value : public std::enable_shared_from_this + { + public: + + virtual ~Value() = default; + + virtual + std::shared_ptr + clone() const = 0; + + virtual void + parse(const std::string& text) const = 0; + + virtual void + parse() const = 0; + + virtual bool + has_default() const = 0; + + virtual bool + is_container() const = 0; + + virtual bool + has_implicit() const = 0; + + virtual std::string + get_default_value() const = 0; + + virtual std::string + get_implicit_value() const = 0; + + virtual std::shared_ptr + default_value(const std::string& value) = 0; + + virtual std::shared_ptr + implicit_value(const std::string& value) = 0; + + virtual std::shared_ptr + no_implicit_value() = 0; + + virtual bool + is_boolean() const = 0; + }; +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + class OptionException : public std::exception + { + public: + explicit OptionException(std::string message) + : m_message(std::move(message)) + { + } + + CXXOPTS_NODISCARD + const char* + what() const noexcept override + { + return m_message.c_str(); + } + + private: + std::string m_message; + }; + + class OptionSpecException : public OptionException + { + public: + + explicit OptionSpecException(const std::string& message) + : OptionException(message) + { + } + }; + + class OptionParseException : public OptionException + { + public: + explicit OptionParseException(const std::string& message) + : OptionException(message) + { + } + }; + + class option_exists_error : public OptionSpecException + { + public: + explicit option_exists_error(const std::string& option) + : OptionSpecException("Option " + LQUOTE + option + RQUOTE + " already exists") + { + } + }; + + class invalid_option_format_error : public OptionSpecException + { + public: + explicit invalid_option_format_error(const std::string& format) + : OptionSpecException("Invalid option format " + LQUOTE + format + RQUOTE) + { + } + }; + + class option_syntax_exception : public OptionParseException { + 
public: + explicit option_syntax_exception(const std::string& text) + : OptionParseException("Argument " + LQUOTE + text + RQUOTE + + " starts with a - but has incorrect syntax") + { + } + }; + + class option_not_exists_exception : public OptionParseException + { + public: + explicit option_not_exists_exception(const std::string& option) + : OptionParseException("Option " + LQUOTE + option + RQUOTE + " does not exist") + { + } + }; + + class missing_argument_exception : public OptionParseException + { + public: + explicit missing_argument_exception(const std::string& option) + : OptionParseException( + "Option " + LQUOTE + option + RQUOTE + " is missing an argument" + ) + { + } + }; + + class option_requires_argument_exception : public OptionParseException + { + public: + explicit option_requires_argument_exception(const std::string& option) + : OptionParseException( + "Option " + LQUOTE + option + RQUOTE + " requires an argument" + ) + { + } + }; + + class option_not_has_argument_exception : public OptionParseException + { + public: + option_not_has_argument_exception + ( + const std::string& option, + const std::string& arg + ) + : OptionParseException( + "Option " + LQUOTE + option + RQUOTE + + " does not take an argument, but argument " + + LQUOTE + arg + RQUOTE + " given" + ) + { + } + }; + + class option_not_present_exception : public OptionParseException + { + public: + explicit option_not_present_exception(const std::string& option) + : OptionParseException("Option " + LQUOTE + option + RQUOTE + " not present") + { + } + }; + + class option_has_no_value_exception : public OptionException + { + public: + explicit option_has_no_value_exception(const std::string& option) + : OptionException( + !option.empty() ? 
+ ("Option " + LQUOTE + option + RQUOTE + " has no value") : + "Option has no value") + { + } + }; + + class argument_incorrect_type : public OptionParseException + { + public: + explicit argument_incorrect_type + ( + const std::string& arg + ) + : OptionParseException( + "Argument " + LQUOTE + arg + RQUOTE + " failed to parse" + ) + { + } + }; + + class option_required_exception : public OptionParseException + { + public: + explicit option_required_exception(const std::string& option) + : OptionParseException( + "Option " + LQUOTE + option + RQUOTE + " is required but not present" + ) + { + } + }; + + template + void throw_or_mimic(const std::string& text) + { + static_assert(std::is_base_of::value, + "throw_or_mimic only works on std::exception and " + "deriving classes"); + +#ifndef CXXOPTS_NO_EXCEPTIONS + // If CXXOPTS_NO_EXCEPTIONS is not defined, just throw + throw T{text}; +#else + // Otherwise manually instantiate the exception, print what() to stderr, + // and exit + T exception{text}; + std::cerr << exception.what() << std::endl; + std::exit(EXIT_FAILURE); +#endif + } + + namespace values + { + namespace parser_tool + { + struct IntegerDesc + { + std::string negative = ""; + std::string base = ""; + std::string value = ""; + }; + struct ArguDesc { + std::string arg_name = ""; + bool grouping = false; + bool set_value = false; + std::string value = ""; + }; +#ifdef CXXOPTS_NO_REGEX + inline IntegerDesc SplitInteger(const std::string &text) + { + if (text.empty()) + { + throw_or_mimic(text); + } + IntegerDesc desc; + const char *pdata = text.c_str(); + if (*pdata == '-') + { + pdata += 1; + desc.negative = "-"; + } + if (strncmp(pdata, "0x", 2) == 0) + { + pdata += 2; + desc.base = "0x"; + } + if (*pdata != '\0') + { + desc.value = std::string(pdata); + } + else + { + throw_or_mimic(text); + } + return desc; + } + + inline bool IsTrueText(const std::string &text) + { + const char *pdata = text.c_str(); + if (*pdata == 't' || *pdata == 'T') + { + pdata += 1; 
+ if (strncmp(pdata, "rue\0", 4) == 0) + { + return true; + } + } + else if (strncmp(pdata, "1\0", 2) == 0) + { + return true; + } + return false; + } + + inline bool IsFalseText(const std::string &text) + { + const char *pdata = text.c_str(); + if (*pdata == 'f' || *pdata == 'F') + { + pdata += 1; + if (strncmp(pdata, "alse\0", 5) == 0) + { + return true; + } + } + else if (strncmp(pdata, "0\0", 2) == 0) + { + return true; + } + return false; + } + + inline std::pair SplitSwitchDef(const std::string &text) + { + std::string short_sw, long_sw; + const char *pdata = text.c_str(); + if (isalnum(*pdata) && *(pdata + 1) == ',') { + short_sw = std::string(1, *pdata); + pdata += 2; + } + while (*pdata == ' ') { pdata += 1; } + if (isalnum(*pdata)) { + const char *store = pdata; + pdata += 1; + while (isalnum(*pdata) || *pdata == '-' || *pdata == '_') { + pdata += 1; + } + if (*pdata == '\0') { + long_sw = std::string(store, pdata - store); + } else { + throw_or_mimic(text); + } + } + return std::pair(short_sw, long_sw); + } + + inline ArguDesc ParseArgument(const char *arg, bool &matched) + { + ArguDesc argu_desc; + const char *pdata = arg; + matched = false; + if (strncmp(pdata, "--", 2) == 0) + { + pdata += 2; + if (isalnum(*pdata)) + { + argu_desc.arg_name.push_back(*pdata); + pdata += 1; + while (isalnum(*pdata) || *pdata == '-' || *pdata == '_') + { + argu_desc.arg_name.push_back(*pdata); + pdata += 1; + } + if (argu_desc.arg_name.length() > 1) + { + if (*pdata == '=') + { + argu_desc.set_value = true; + pdata += 1; + if (*pdata != '\0') + { + argu_desc.value = std::string(pdata); + } + matched = true; + } + else if (*pdata == '\0') + { + matched = true; + } + } + } + } + else if (strncmp(pdata, "-", 1) == 0) + { + pdata += 1; + argu_desc.grouping = true; + while (isalnum(*pdata)) + { + argu_desc.arg_name.push_back(*pdata); + pdata += 1; + } + matched = !argu_desc.arg_name.empty() && *pdata == '\0'; + } + return argu_desc; + } + +#else // CXXOPTS_NO_REGEX + + 
namespace + { + + std::basic_regex integer_pattern + ("(-)?(0x)?([0-9a-zA-Z]+)|((0x)?0)"); + std::basic_regex truthy_pattern + ("(t|T)(rue)?|1"); + std::basic_regex falsy_pattern + ("(f|F)(alse)?|0"); + + std::basic_regex option_matcher + ("--([[:alnum:]][-_[:alnum:]]+)(=(.*))?|-([[:alnum:]]+)"); + std::basic_regex option_specifier + ("(([[:alnum:]]),)?[ ]*([[:alnum:]][-_[:alnum:]]*)?"); + + } // namespace + + inline IntegerDesc SplitInteger(const std::string &text) + { + std::smatch match; + std::regex_match(text, match, integer_pattern); + + if (match.length() == 0) + { + throw_or_mimic(text); + } + + IntegerDesc desc; + desc.negative = match[1]; + desc.base = match[2]; + desc.value = match[3]; + + if (match.length(4) > 0) + { + desc.base = match[5]; + desc.value = "0"; + return desc; + } + + return desc; + } + + inline bool IsTrueText(const std::string &text) + { + std::smatch result; + std::regex_match(text, result, truthy_pattern); + return !result.empty(); + } + + inline bool IsFalseText(const std::string &text) + { + std::smatch result; + std::regex_match(text, result, falsy_pattern); + return !result.empty(); + } + + inline std::pair SplitSwitchDef(const std::string &text) + { + std::match_results result; + std::regex_match(text.c_str(), result, option_specifier); + if (result.empty()) + { + throw_or_mimic(text); + } + + const std::string& short_sw = result[2]; + const std::string& long_sw = result[3]; + + return std::pair(short_sw, long_sw); + } + + inline ArguDesc ParseArgument(const char *arg, bool &matched) + { + std::match_results result; + std::regex_match(arg, result, option_matcher); + matched = !result.empty(); + + ArguDesc argu_desc; + if (matched) { + argu_desc.arg_name = result[1].str(); + argu_desc.set_value = result[2].length() > 0; + argu_desc.value = result[3].str(); + if (result[4].length() > 0) + { + argu_desc.grouping = true; + argu_desc.arg_name = result[4].str(); + } + } + + return argu_desc; + } + +#endif // CXXOPTS_NO_REGEX +#undef 
CXXOPTS_NO_REGEX + } + + namespace detail + { + template + struct SignedCheck; + + template + struct SignedCheck + { + template + void + operator()(bool negative, U u, const std::string& text) + { + if (negative) + { + if (u > static_cast((std::numeric_limits::min)())) + { + throw_or_mimic(text); + } + } + else + { + if (u > static_cast((std::numeric_limits::max)())) + { + throw_or_mimic(text); + } + } + } + }; + + template + struct SignedCheck + { + template + void + operator()(bool, U, const std::string&) const {} + }; + + template + void + check_signed_range(bool negative, U value, const std::string& text) + { + SignedCheck::is_signed>()(negative, value, text); + } + } // namespace detail + + template + void + checked_negate(R& r, T&& t, const std::string&, std::true_type) + { + // if we got to here, then `t` is a positive number that fits into + // `R`. So to avoid MSVC C4146, we first cast it to `R`. + // See https://github.com/jarro2783/cxxopts/issues/62 for more details. + r = static_cast(-static_cast(t-1)-1); + } + + template + void + checked_negate(R&, T&&, const std::string& text, std::false_type) + { + throw_or_mimic(text); + } + + template + void + integer_parser(const std::string& text, T& value) + { + parser_tool::IntegerDesc int_desc = parser_tool::SplitInteger(text); + + using US = typename std::make_unsigned::type; + constexpr bool is_signed = std::numeric_limits::is_signed; + + const bool negative = int_desc.negative.length() > 0; + const uint8_t base = int_desc.base.length() > 0 ? 
16 : 10; + const std::string & value_match = int_desc.value; + + US result = 0; + + for (char ch : value_match) + { + US digit = 0; + + if (ch >= '0' && ch <= '9') + { + digit = static_cast(ch - '0'); + } + else if (base == 16 && ch >= 'a' && ch <= 'f') + { + digit = static_cast(ch - 'a' + 10); + } + else if (base == 16 && ch >= 'A' && ch <= 'F') + { + digit = static_cast(ch - 'A' + 10); + } + else + { + throw_or_mimic(text); + } + + const US next = static_cast(result * base + digit); + if (result > next) + { + throw_or_mimic(text); + } + + result = next; + } + + detail::check_signed_range(negative, result, text); + + if (negative) + { + checked_negate(value, result, text, std::integral_constant()); + } + else + { + value = static_cast(result); + } + } + + template + void stringstream_parser(const std::string& text, T& value) + { + std::stringstream in(text); + in >> value; + if (!in) { + throw_or_mimic(text); + } + } + + template ::value>::type* = nullptr + > + void parse_value(const std::string& text, T& value) + { + integer_parser(text, value); + } + + inline + void + parse_value(const std::string& text, bool& value) + { + if (parser_tool::IsTrueText(text)) + { + value = true; + return; + } + + if (parser_tool::IsFalseText(text)) + { + value = false; + return; + } + + throw_or_mimic(text); + } + + inline + void + parse_value(const std::string& text, std::string& value) + { + value = text; + } + + // The fallback parser. It uses the stringstream parser to parse all types + // that have not been overloaded explicitly. It has to be placed in the + // source code before all other more specialized templates. 
+ template ::value>::type* = nullptr + > + void + parse_value(const std::string& text, T& value) { + stringstream_parser(text, value); + } + + template + void + parse_value(const std::string& text, std::vector& value) + { + if (text.empty()) { + T v; + parse_value(text, v); + value.emplace_back(std::move(v)); + return; + } + std::stringstream in(text); + std::string token; + while(!in.eof() && std::getline(in, token, CXXOPTS_VECTOR_DELIMITER)) { + T v; + parse_value(token, v); + value.emplace_back(std::move(v)); + } + } + +#ifdef CXXOPTS_HAS_OPTIONAL + template + void + parse_value(const std::string& text, std::optional& value) + { + T result; + parse_value(text, result); + value = std::move(result); + } +#endif + + inline + void parse_value(const std::string& text, char& c) + { + if (text.length() != 1) + { + throw_or_mimic(text); + } + + c = text[0]; + } + + template + struct type_is_container + { + static constexpr bool value = false; + }; + + template + struct type_is_container> + { + static constexpr bool value = true; + }; + + template + class abstract_value : public Value + { + using Self = abstract_value; + + public: + abstract_value() + : m_result(std::make_shared()) + , m_store(m_result.get()) + { + } + + explicit abstract_value(T* t) + : m_store(t) + { + } + + ~abstract_value() override = default; + + abstract_value& operator=(const abstract_value&) = default; + + abstract_value(const abstract_value& rhs) + { + if (rhs.m_result) + { + m_result = std::make_shared(); + m_store = m_result.get(); + } + else + { + m_store = rhs.m_store; + } + + m_default = rhs.m_default; + m_implicit = rhs.m_implicit; + m_default_value = rhs.m_default_value; + m_implicit_value = rhs.m_implicit_value; + } + + void + parse(const std::string& text) const override + { + parse_value(text, *m_store); + } + + bool + is_container() const override + { + return type_is_container::value; + } + + void + parse() const override + { + parse_value(m_default_value, *m_store); + } + + bool + 
has_default() const override + { + return m_default; + } + + bool + has_implicit() const override + { + return m_implicit; + } + + std::shared_ptr + default_value(const std::string& value) override + { + m_default = true; + m_default_value = value; + return shared_from_this(); + } + + std::shared_ptr + implicit_value(const std::string& value) override + { + m_implicit = true; + m_implicit_value = value; + return shared_from_this(); + } + + std::shared_ptr + no_implicit_value() override + { + m_implicit = false; + return shared_from_this(); + } + + std::string + get_default_value() const override + { + return m_default_value; + } + + std::string + get_implicit_value() const override + { + return m_implicit_value; + } + + bool + is_boolean() const override + { + return std::is_same::value; + } + + const T& + get() const + { + if (m_store == nullptr) + { + return *m_result; + } + return *m_store; + } + + protected: + std::shared_ptr m_result{}; + T* m_store{}; + + bool m_default = false; + bool m_implicit = false; + + std::string m_default_value{}; + std::string m_implicit_value{}; + }; + + template + class standard_value : public abstract_value + { + public: + using abstract_value::abstract_value; + + CXXOPTS_NODISCARD + std::shared_ptr + clone() const override + { + return std::make_shared>(*this); + } + }; + + template <> + class standard_value : public abstract_value + { + public: + ~standard_value() override = default; + + standard_value() + { + set_default_and_implicit(); + } + + explicit standard_value(bool* b) + : abstract_value(b) + { + set_default_and_implicit(); + } + + std::shared_ptr + clone() const override + { + return std::make_shared>(*this); + } + + private: + + void + set_default_and_implicit() + { + m_default = true; + m_default_value = "false"; + m_implicit = true; + m_implicit_value = "true"; + } + }; + } // namespace values + + template + std::shared_ptr + value() + { + return std::make_shared>(); + } + + template + std::shared_ptr + value(T& t) 
+ { + return std::make_shared>(&t); + } + + class OptionAdder; + + class OptionDetails + { + public: + OptionDetails + ( + std::string short_, + std::string long_, + String desc, + std::shared_ptr val + ) + : m_short(std::move(short_)) + , m_long(std::move(long_)) + , m_desc(std::move(desc)) + , m_value(std::move(val)) + , m_count(0) + { + m_hash = std::hash{}(m_long + m_short); + } + + OptionDetails(const OptionDetails& rhs) + : m_desc(rhs.m_desc) + , m_value(rhs.m_value->clone()) + , m_count(rhs.m_count) + { + } + + OptionDetails(OptionDetails&& rhs) = default; + + CXXOPTS_NODISCARD + const String& + description() const + { + return m_desc; + } + + CXXOPTS_NODISCARD + const Value& + value() const { + return *m_value; + } + + CXXOPTS_NODISCARD + std::shared_ptr + make_storage() const + { + return m_value->clone(); + } + + CXXOPTS_NODISCARD + const std::string& + short_name() const + { + return m_short; + } + + CXXOPTS_NODISCARD + const std::string& + long_name() const + { + return m_long; + } + + size_t + hash() const + { + return m_hash; + } + + private: + std::string m_short{}; + std::string m_long{}; + String m_desc{}; + std::shared_ptr m_value{}; + int m_count; + + size_t m_hash{}; + }; + + struct HelpOptionDetails + { + std::string s; + std::string l; + String desc; + bool has_default; + std::string default_value; + bool has_implicit; + std::string implicit_value; + std::string arg_help; + bool is_container; + bool is_boolean; + }; + + struct HelpGroupDetails + { + std::string name{}; + std::string description{}; + std::vector options{}; + }; + + class OptionValue + { + public: + void + parse + ( + const std::shared_ptr& details, + const std::string& text + ) + { + ensure_value(details); + ++m_count; + m_value->parse(text); + m_long_name = &details->long_name(); + } + + void + parse_default(const std::shared_ptr& details) + { + ensure_value(details); + m_default = true; + m_long_name = &details->long_name(); + m_value->parse(); + } + + void + 
parse_no_value(const std::shared_ptr& details) + { + m_long_name = &details->long_name(); + } + +#if defined(CXXOPTS_NULL_DEREF_IGNORE) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnull-dereference" +#endif + + CXXOPTS_NODISCARD + size_t + count() const noexcept + { + return m_count; + } + +#if defined(CXXOPTS_NULL_DEREF_IGNORE) +#pragma GCC diagnostic pop +#endif + + // TODO: maybe default options should count towards the number of arguments + CXXOPTS_NODISCARD + bool + has_default() const noexcept + { + return m_default; + } + + template + const T& + as() const + { + if (m_value == nullptr) { + throw_or_mimic( + m_long_name == nullptr ? "" : *m_long_name); + } + +#ifdef CXXOPTS_NO_RTTI + return static_cast&>(*m_value).get(); +#else + return dynamic_cast&>(*m_value).get(); +#endif + } + + private: + void + ensure_value(const std::shared_ptr& details) + { + if (m_value == nullptr) + { + m_value = details->make_storage(); + } + } + + + const std::string* m_long_name = nullptr; + // Holding this pointer is safe, since OptionValue's only exist in key-value pairs, + // where the key has the string we point to. 
+ std::shared_ptr m_value{}; + size_t m_count = 0; + bool m_default = false; + }; + + class KeyValue + { + public: + KeyValue(std::string key_, std::string value_) + : m_key(std::move(key_)) + , m_value(std::move(value_)) + { + } + + CXXOPTS_NODISCARD + const std::string& + key() const + { + return m_key; + } + + CXXOPTS_NODISCARD + const std::string& + value() const + { + return m_value; + } + + template + T + as() const + { + T result; + values::parse_value(m_value, result); + return result; + } + + private: + std::string m_key; + std::string m_value; + }; + + using ParsedHashMap = std::unordered_map; + using NameHashMap = std::unordered_map; + + class ParseResult + { + public: + class Iterator + { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = KeyValue; + using difference_type = void; + using pointer = const KeyValue*; + using reference = const KeyValue&; + + Iterator() = default; + Iterator(const Iterator&) = default; + + Iterator(const ParseResult *pr, bool end=false) + : m_pr(pr) + , m_iter(end? 
pr->m_defaults.end(): pr->m_sequential.begin()) + { + } + + Iterator& operator++() + { + ++m_iter; + if(m_iter == m_pr->m_sequential.end()) + { + m_iter = m_pr->m_defaults.begin(); + return *this; + } + return *this; + } + + Iterator operator++(int) + { + Iterator retval = *this; + ++(*this); + return retval; + } + + bool operator==(const Iterator& other) const + { + return m_iter == other.m_iter; + } + + bool operator!=(const Iterator& other) const + { + return !(*this == other); + } + + const KeyValue& operator*() + { + return *m_iter; + } + + const KeyValue* operator->() + { + return m_iter.operator->(); + } + + private: + const ParseResult* m_pr; + std::vector::const_iterator m_iter; + }; + + ParseResult() = default; + ParseResult(const ParseResult&) = default; + + ParseResult(NameHashMap&& keys, ParsedHashMap&& values, std::vector sequential, + std::vector default_opts, std::vector&& unmatched_args) + : m_keys(std::move(keys)) + , m_values(std::move(values)) + , m_sequential(std::move(sequential)) + , m_defaults(std::move(default_opts)) + , m_unmatched(std::move(unmatched_args)) + { + } + + ParseResult& operator=(ParseResult&&) = default; + ParseResult& operator=(const ParseResult&) = default; + + Iterator + begin() const + { + return Iterator(this); + } + + Iterator + end() const + { + return Iterator(this, true); + } + + size_t + count(const std::string& o) const + { + auto iter = m_keys.find(o); + if (iter == m_keys.end()) + { + return 0; + } + + auto viter = m_values.find(iter->second); + + if (viter == m_values.end()) + { + return 0; + } + + return viter->second.count(); + } + + const OptionValue& + operator[](const std::string& option) const + { + auto iter = m_keys.find(option); + + if (iter == m_keys.end()) + { + throw_or_mimic(option); + } + + auto viter = m_values.find(iter->second); + + if (viter == m_values.end()) + { + throw_or_mimic(option); + } + + return viter->second; + } + + const std::vector& + arguments() const + { + return m_sequential; + 
} + + const std::vector& + unmatched() const + { + return m_unmatched; + } + + const std::vector& + defaults() const + { + return m_defaults; + } + + const std::string + arguments_string() const + { + std::string result; + for(const auto& kv: m_sequential) + { + result += kv.key() + " = " + kv.value() + "\n"; + } + for(const auto& kv: m_defaults) + { + result += kv.key() + " = " + kv.value() + " " + "(default)" + "\n"; + } + return result; + } + + private: + NameHashMap m_keys{}; + ParsedHashMap m_values{}; + std::vector m_sequential{}; + std::vector m_defaults{}; + std::vector m_unmatched{}; + }; + + struct Option + { + Option + ( + std::string opts, + std::string desc, + std::shared_ptr value = ::cxxopts::value(), + std::string arg_help = "" + ) + : opts_(std::move(opts)) + , desc_(std::move(desc)) + , value_(std::move(value)) + , arg_help_(std::move(arg_help)) + { + } + + std::string opts_; + std::string desc_; + std::shared_ptr value_; + std::string arg_help_; + }; + + using OptionMap = std::unordered_map>; + using PositionalList = std::vector; + using PositionalListIterator = PositionalList::const_iterator; + + class OptionParser + { + public: + OptionParser(const OptionMap& options, const PositionalList& positional, bool allow_unrecognised) + : m_options(options) + , m_positional(positional) + , m_allow_unrecognised(allow_unrecognised) + { + } + + ParseResult + parse(int argc, const char* const* argv); + + bool + consume_positional(const std::string& a, PositionalListIterator& next); + + void + checked_parse_arg + ( + int argc, + const char* const* argv, + int& current, + const std::shared_ptr& value, + const std::string& name + ); + + void + add_to_option(OptionMap::const_iterator iter, const std::string& option, const std::string& arg); + + void + parse_option + ( + const std::shared_ptr& value, + const std::string& name, + const std::string& arg = "" + ); + + void + parse_default(const std::shared_ptr& details); + + void + parse_no_value(const 
std::shared_ptr& details); + + private: + + void finalise_aliases(); + + const OptionMap& m_options; + const PositionalList& m_positional; + + std::vector m_sequential{}; + std::vector m_defaults{}; + bool m_allow_unrecognised; + + ParsedHashMap m_parsed{}; + NameHashMap m_keys{}; + }; + + class Options + { + public: + + explicit Options(std::string program, std::string help_string = "") + : m_program(std::move(program)) + , m_help_string(toLocalString(std::move(help_string))) + , m_custom_help("[OPTION...]") + , m_positional_help("positional parameters") + , m_show_positional(false) + , m_allow_unrecognised(false) + , m_width(76) + , m_tab_expansion(false) + , m_options(std::make_shared()) + { + } + + Options& + positional_help(std::string help_text) + { + m_positional_help = std::move(help_text); + return *this; + } + + Options& + custom_help(std::string help_text) + { + m_custom_help = std::move(help_text); + return *this; + } + + Options& + show_positional_help() + { + m_show_positional = true; + return *this; + } + + Options& + allow_unrecognised_options() + { + m_allow_unrecognised = true; + return *this; + } + + Options& + set_width(size_t width) + { + m_width = width; + return *this; + } + + Options& + set_tab_expansion(bool expansion=true) + { + m_tab_expansion = expansion; + return *this; + } + + ParseResult + parse(int argc, const char* const* argv); + + OptionAdder + add_options(std::string group = ""); + + void + add_options + ( + const std::string& group, + std::initializer_list