Skip to content

Commit

Permalink
Additional testing for relu/relup.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Nov 5, 2023
1 parent 0a28558 commit abd3c0c
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions test/relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <heyoka/config.hpp>

#include <functional>
#include <limits>
#include <random>
#include <sstream>
Expand Down Expand Up @@ -167,6 +168,32 @@ TEST_CASE("names")
}
}

// Test to check that equality, hashing and less-than, which take into account
// the function name, behave correctly when changing slope.
TEST_CASE("hash eq lt")
{
auto [x, y] = make_vars("x", "y");

REQUIRE(relu(x + y) != relu(x + y, 0.01));
REQUIRE(relup(x + y) != relup(x + y, 0.01));
REQUIRE(relu(x + y, 0.02) != relu(x + y, 0.01));
REQUIRE(relup(x + y, 0.02) != relup(x + y, 0.01));
REQUIRE((std::get<func>(relu(x + y).value()) < std::get<func>(relu(x + y, 0.01).value())
|| std::get<func>(relu(x + y, 0.01).value()) < std::get<func>(relu(x + y).value())));
REQUIRE((std::get<func>(relup(x + y).value()) < std::get<func>(relup(x + y, 0.01).value())
|| std::get<func>(relup(x + y, 0.01).value()) < std::get<func>(relup(x + y).value())));
REQUIRE((std::get<func>(relu(x + y, 0.02).value()) < std::get<func>(relu(x + y, 0.01).value())
|| std::get<func>(relu(x + y, 0.01).value()) < std::get<func>(relu(x + y, 0.02).value())));
REQUIRE((std::get<func>(relup(x + y, 0.02).value()) < std::get<func>(relup(x + y, 0.01).value())
|| std::get<func>(relup(x + y, 0.01).value()) < std::get<func>(relup(x + y, 0.02).value())));

// Of course, not 100% guaranteed but hopefully very likely.
REQUIRE(std::hash<expression>{}(relu(x + y)) != std::hash<expression>{}(relu(x + y, 0.01)));
REQUIRE(std::hash<expression>{}(relup(x + y)) != std::hash<expression>{}(relup(x + y, 0.01)));
REQUIRE(std::hash<expression>{}(relu(x + y, 0.02)) != std::hash<expression>{}(relu(x + y, 0.01)));
REQUIRE(std::hash<expression>{}(relup(x + y, 0.02)) != std::hash<expression>{}(relup(x + y, 0.01)));
}

TEST_CASE("invalid slopes")
{
using Catch::Matchers::Message;
Expand Down

0 comments on commit abd3c0c

Please sign in to comment.