Skip to content

Commit

Permalink
Initial implementation/testing for relu() and relup().
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Nov 1, 2023
1 parent 2a0d493 commit 18f84a8
Show file tree
Hide file tree
Showing 7 changed files with 968 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ set(HEYOKA_SRC_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/log.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/pow.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/sigmoid.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/relu.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/sin.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/sqrt.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/tan.cpp"
Expand Down
1 change: 1 addition & 0 deletions include/heyoka/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <heyoka/math/log.hpp>
#include <heyoka/math/pow.hpp>
#include <heyoka/math/prod.hpp>
#include <heyoka/math/relu.hpp>
#include <heyoka/math/sigmoid.hpp>
#include <heyoka/math/sin.hpp>
#include <heyoka/math/sinh.hpp>
Expand Down
99 changes: 99 additions & 0 deletions include/heyoka/math/relu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright 2020, 2021, 2022, 2023 Francesco Biscani ([email protected]), Dario Izzo ([email protected])
//
// This file is part of the heyoka library.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef HEYOKA_MATH_RELU_HPP
#define HEYOKA_MATH_RELU_HPP

#include <cstdint>
// #include <string>
#include <vector>

#include <heyoka/config.hpp>
// #include <heyoka/detail/func_cache.hpp>
#include <heyoka/detail/fwd_decl.hpp>
#include <heyoka/detail/llvm_fwd.hpp>
#include <heyoka/detail/visibility.hpp>
#include <heyoka/func.hpp>
#include <heyoka/s11n.hpp>

HEYOKA_BEGIN_NAMESPACE

namespace detail
{

class HEYOKA_DLL_PUBLIC relu_impl : public func_base
{
friend class boost::serialization::access;
template <typename Archive>
void serialize(Archive &ar, unsigned)
{
ar &boost::serialization::base_object<func_base>(*this);
}

public:
relu_impl();
explicit relu_impl(expression);

[[nodiscard]] expression normalise() const;

[[nodiscard]] std::vector<expression> gradient() const;

[[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector<llvm::Value *> &, llvm::Value *,
llvm::Value *, llvm::Value *, std::uint32_t, bool) const;

[[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const;

llvm::Value *taylor_diff(llvm_state &, llvm::Type *, const std::vector<std::uint32_t> &,
const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *, std::uint32_t,
std::uint32_t, std::uint32_t, std::uint32_t, bool) const;

llvm::Function *taylor_c_diff_func(llvm_state &, llvm::Type *, std::uint32_t, std::uint32_t, bool) const;
};

class HEYOKA_DLL_PUBLIC relup_impl : public func_base
{
friend class boost::serialization::access;
template <typename Archive>
void serialize(Archive &ar, unsigned)
{
ar &boost::serialization::base_object<func_base>(*this);
}

public:
relup_impl();
explicit relup_impl(expression);

[[nodiscard]] expression normalise() const;

[[nodiscard]] std::vector<expression> gradient() const;

[[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector<llvm::Value *> &, llvm::Value *,
llvm::Value *, llvm::Value *, std::uint32_t, bool) const;

[[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const;

llvm::Value *taylor_diff(llvm_state &, llvm::Type *, const std::vector<std::uint32_t> &,
const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *, std::uint32_t,
std::uint32_t, std::uint32_t, std::uint32_t, bool) const;

llvm::Function *taylor_c_diff_func(llvm_state &, llvm::Type *, std::uint32_t, std::uint32_t, bool) const;
};

} // namespace detail

HEYOKA_DLL_PUBLIC expression relu(expression);

HEYOKA_DLL_PUBLIC expression relup(expression);

HEYOKA_END_NAMESPACE

HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::relu_impl)

HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::relup_impl)

#endif
Loading

0 comments on commit 18f84a8

Please sign in to comment.