From 9aec85dc4bf89f2e67dccdf4ae72df3e8d88aa8d Mon Sep 17 00:00:00 2001 From: Joseph Tan Date: Tue, 13 Jul 2021 00:36:52 -0700 Subject: [PATCH] Added support for arithmetic with different types Adds support for operations like int16_t + int64_t or uint8_t + double(the operand with the smaller size is promoted to the operand with the larger size) and tests for this new functionality. Also adds a few helper functions including StaticCastOrConvertLiteral, CastOperandAndDoArithmetic, and CastOperandAndDoComparison. o Also adds a GetAs member to the AstLiteralExpression to get the value of a literal as type T --- pochivm/api_base.h | 428 ++++++++++++++-------- pochivm/ast_type_helper.h | 36 ++ pochivm/common_expr.h | 17 + pochivm/for_each_primitive_type.h | 17 + test_sanity_arith_expr.cpp | 571 +++++++++++++++++------------- 5 files changed, 685 insertions(+), 384 deletions(-) diff --git a/pochivm/api_base.h b/pochivm/api_base.h index 86b15cf..9be542f 100644 --- a/pochivm/api_base.h +++ b/pochivm/api_base.h @@ -115,113 +115,189 @@ Value::operator Value() const return StaticCast(*this); } +// Language utility: construct a literal +// Example: Literal(1) +// +template +Value Literal(T x) +{ + return Value(new AstLiteralExpr(TypeId::Get(), &x)); +} + +// If src is a literal expression, return a new equivalent literal expression of type U. Otherwise, return +// a StaticCast of src to U +// +template::value> > +Value StaticCastOrConvertLiteral(const Value& src) +{ + static_assert(AstTypeHelper::may_static_cast::value, "cannot static_cast T to U"); + if(src.__pochivm_value_ptr->GetAstNodeType() == AstNodeType::AstLiteralExpr) + { + return Literal(static_cast(assert_cast(src.__pochivm_value_ptr)->template GetAs())); + } + else + { + return StaticCast(src); + } +} + +// Helper for arithmetic ops operator overloading. Returns an arithmetic expression of the form +// lhs OP rhs where OP is the operator specified by `expr_type`. If the operands have different +// types, casts the value of the 'smaller' type to the 'larger' type. +// Identical to the C implicit promotion rules except doesn't necessarily cast from types smaller +// than the int type to the int type. Both types must have same signedess. +// +template ::value>, + typename = std::enable_if_t::value> > +Value::type> CastOperandAndDoArithmetic(const Value& lhs, const Value& rhs) +{ + using ReturnType = typename AstTypeHelper::ArithReturnType::type; + static_assert(std::is_signed::value == std::is_signed::value || + std::is_floating_point::value, + "cannot add two values of different signedness"); + if constexpr (!std::is_same::value) + { + static_assert(std::is_same::value, "internal bug: rhs type is not the same as return type"); + return Value(new AstArithmeticExpr(expr_type, + StaticCastOrConvertLiteral(lhs).__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + } + else if constexpr (!std::is_same::value) + { + static_assert(std::is_same::value, "internal bug: lhs type is not the same as return type"); + return Value(new AstArithmeticExpr(expr_type, + lhs.__pochivm_value_ptr, StaticCastOrConvertLiteral(rhs).__pochivm_value_ptr)); + } + else + { + static_assert(std::is_same::value && std::is_same::value, "internal bug: type of lhs and rhs aren't the same as return type"); + return Value(new AstArithmeticExpr(expr_type, + lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + } +} + // Arithmetic ops convenience operator overloading // -template::value> > -Value operator+(const Value& lhs, const Value& rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator+(const Value& lhs, const Value& rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::ADD, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + return CastOperandAndDoArithmetic(lhs, rhs); } -template::value> > -Value operator-(const Value& lhs, const Value& rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator-(const Value& lhs, const Value& rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::SUB, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + return CastOperandAndDoArithmetic(lhs, rhs); } -template::value> > -Value operator*(const Value& lhs, const Value& rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator*(const Value& lhs, const Value& rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::MUL, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + return CastOperandAndDoArithmetic(lhs, rhs); } -template::value> > -Value operator/(const Value& lhs, const Value& rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator/(const Value& lhs, const Value& rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::DIV, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + return CastOperandAndDoArithmetic(lhs, rhs); } -template::value> > -Value operator%(const Value& lhs, const Value& rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator%(const Value& lhs, const Value& rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::MOD, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + return CastOperandAndDoArithmetic(lhs, rhs); } // Convenience overloading: arithmetic operation with literal // -template::value> > -Value operator+(const Value& lhs, T rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator+(const Value& lhs, U rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::ADD, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + return CastOperandAndDoArithmetic(lhs, Literal(rhs)); } -template::value> > -Value operator-(const Value& lhs, T rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator-(const Value& lhs, U rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::SUB, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + return CastOperandAndDoArithmetic(lhs, Literal(rhs)); } -template::value> > -Value operator*(const Value& lhs, T rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator*(const Value& lhs, U rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::MUL, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + return CastOperandAndDoArithmetic(lhs, Literal(rhs)); } -template::value> > -Value operator/(const Value& lhs, T rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator/(const Value& lhs, U rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::DIV, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + return CastOperandAndDoArithmetic(lhs, Literal(rhs)); } -template::value> > -Value operator%(const Value& lhs, T rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator%(const Value& lhs, U rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::MOD, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + return CastOperandAndDoArithmetic(lhs, Literal(rhs)); } -template::value> > -Value operator+(T lhs, const Value& rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator+(T lhs, const Value& rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::ADD, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + return CastOperandAndDoArithmetic(Literal(lhs), rhs); } -template::value> > -Value operator-(T lhs, const Value& rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator-(T lhs, const Value& rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::SUB, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + return CastOperandAndDoArithmetic(Literal(lhs), rhs); } -template::value> > -Value operator*(T lhs, const Value& rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator*(T lhs, const Value& rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::MUL, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + return CastOperandAndDoArithmetic(Literal(lhs), rhs); } -template::value> > -Value operator/(T lhs, const Value& rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator/(T lhs, const Value& rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::DIV, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + return CastOperandAndDoArithmetic(Literal(lhs), rhs); } -template::value> > -Value operator%(T lhs, const Value& rhs) +template ::value>, + typename = std::enable_if_t::value> > +Value::type> operator%(T lhs, const Value& rhs) { - return Value(new AstArithmeticExpr(AstArithmeticExprType::MOD, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + return CastOperandAndDoArithmetic(Literal(lhs), rhs); } // Pointer arithmetic ops convenience operator overloading @@ -260,142 +336,209 @@ Value operator-(const Value& base, I index) return Value(new AstPointerArithmeticExpr(base.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &index), false /*isAddition*/)); } +// Helper for comparison ops operator overloading. Returns an comparison expression of the form +// lhs OP rhs where OP is the operator specified by `expr_type`. If the operands have different +// types, casts the value of the 'smaller' type to the 'larger' type. +// Identical to the C implicit promotion rules except doesn't necessarily cast from types smaller +// than the int type to the int type. Both types must have same signedess. Cannot compare a non-bool to a bool. +// +template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value == std::is_same::value> > +Value CastOperandAndDoComparison(const Value &lhs, const Value &rhs) +{ + using CommonType = typename AstTypeHelper::ArithReturnType::type; + static_assert(std::is_signed::value == std::is_signed::value || + std::is_floating_point::value, + "cannot compare two values of different signedness"); + if constexpr (!std::is_same::value) + { + static_assert(std::is_same::value, "internal bug: rhs type is not the same as return type"); + return Value(new AstComparisonExpr(expr_type, + StaticCastOrConvertLiteral(lhs).__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + } + else if constexpr (!std::is_same::value) + { + static_assert(std::is_same::value, "internal bug: lhs type is not the same as return type"); + return Value(new AstComparisonExpr(expr_type, + lhs.__pochivm_value_ptr, StaticCastOrConvertLiteral(rhs).__pochivm_value_ptr)); + } + else + { + static_assert(std::is_same::value && std::is_same::value, "internal bug: type of lhs and rhs aren't the same as return type"); + return Value(new AstComparisonExpr(expr_type, + lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + } +} + // Comparison ops convenience operator overloading // -template::value> > -Value operator==(const Value& lhs, const Value& rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator==(const Value& lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::EQUAL, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + static_assert(std::is_same::value == std::is_same::value, "cannot compare a nonbool to a bool"); // Ensure we're not comparing a bool to a non-bool + return CastOperandAndDoComparison(lhs, rhs); } -template::value> > -Value operator!=(const Value& lhs, const Value& rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator!=(const Value& lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::NOT_EQUAL, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + static_assert(std::is_same::value == std::is_same::value, "cannot compare a nonbool to a bool"); // Ensure we're not comparing a bool to a non-bool + return CastOperandAndDoComparison(lhs, rhs); } -template::value> > -Value operator<(const Value& lhs, const Value& rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator<(const Value& lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::LESS_THAN, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + return CastOperandAndDoComparison(lhs, rhs); } -template::value> > -Value operator>(const Value& lhs, const Value& rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator>(const Value& lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::GREATER_THAN, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + return CastOperandAndDoComparison(lhs, rhs); } -template::value - && AstTypeHelper::primitive_type_supports_binary_op::value)> > -Value operator<=(const Value& lhs, const Value& rhs) +template::value + && AstTypeHelper::primitive_type_supports_binary_op::value)>, + typename = std::enable_if_t<(AstTypeHelper::primitive_type_supports_binary_op::value + && AstTypeHelper::primitive_type_supports_binary_op::value)> > +Value operator<=(const Value& lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::LESS_EQUAL, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + return CastOperandAndDoComparison(lhs, rhs); } -template::value - && AstTypeHelper::primitive_type_supports_binary_op::value)> > -Value operator>=(const Value& lhs, const Value& rhs) +template::value + && AstTypeHelper::primitive_type_supports_binary_op::value)>, + typename = std::enable_if_t<(AstTypeHelper::primitive_type_supports_binary_op::value + && AstTypeHelper::primitive_type_supports_binary_op::value)> > +Value operator>=(const Value& lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::GREATER_EQUAL, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); + return CastOperandAndDoComparison(lhs, rhs); } // Convenience overloading: comparing with a literal // -template::value> > -Value operator==(const Value& lhs, T rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator==(const Value& lhs, U rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::EQUAL, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + static_assert(std::is_same::value == std::is_same::value, "cannot compare a nonbool to a bool"); // Ensure we're not comparing a bool to a non-bool + return CastOperandAndDoComparison(lhs, Literal(rhs)); } -template::value> > -Value operator!=(const Value& lhs, T rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator!=(const Value& lhs, U rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::NOT_EQUAL, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + static_assert(std::is_same::value == std::is_same::value, "cannot compare a nonbool to a bool"); // Ensure we're not comparing a bool to a non-bool + return CastOperandAndDoComparison(lhs, Literal(rhs)); } -template::value> > -Value operator<(const Value& lhs, T rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator<(const Value& lhs, U rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::LESS_THAN, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + return CastOperandAndDoComparison(lhs, Literal(rhs)); } -template::value> > -Value operator>(const Value& lhs, T rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator>(const Value& lhs, U rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::GREATER_THAN, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + return CastOperandAndDoComparison(lhs, Literal(rhs)); } -template::value - && AstTypeHelper::primitive_type_supports_binary_op::value)> > -Value operator<=(const Value& lhs, T rhs) +template::value + && AstTypeHelper::primitive_type_supports_binary_op::value)>, + typename = std::enable_if_t<(AstTypeHelper::primitive_type_supports_binary_op::value + && AstTypeHelper::primitive_type_supports_binary_op::value)> > +Value operator<=(const Value& lhs, U rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::LESS_EQUAL, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + return CastOperandAndDoComparison(lhs, Literal(rhs)); } -template::value - && AstTypeHelper::primitive_type_supports_binary_op::value)> > -Value operator>=(const Value& lhs, T rhs) +template::value + && AstTypeHelper::primitive_type_supports_binary_op::value)>, + typename = std::enable_if_t<(AstTypeHelper::primitive_type_supports_binary_op::value + && AstTypeHelper::primitive_type_supports_binary_op::value)> > +Value operator>=(const Value& lhs, U rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::GREATER_EQUAL, lhs.__pochivm_value_ptr, new AstLiteralExpr(TypeId::Get(), &rhs))); + return CastOperandAndDoComparison(lhs, Literal(rhs)); } -template::value> > -Value operator==(T lhs, const Value& rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator==(T lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::EQUAL, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + static_assert(std::is_same::value == std::is_same::value, "cannot compare a nonbool to a bool"); // Ensure we're not comparing a bool to a non-bool + return CastOperandAndDoComparison(Literal(lhs), rhs); } -template::value> > -Value operator!=(T lhs, const Value& rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator!=(T lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::NOT_EQUAL, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + static_assert(std::is_same::value == std::is_same::value, "cannot compare a nonbool to a bool"); // Ensure we're not comparing a bool to a non-bool + return CastOperandAndDoComparison(Literal(lhs), rhs); } -template::value> > -Value operator<(T lhs, const Value& rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator<(T lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::LESS_THAN, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + return CastOperandAndDoComparison(Literal(lhs), rhs); } -template::value> > -Value operator>(T lhs, const Value& rhs) +template::value>, + typename = std::enable_if_t::value> > +Value operator>(T lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::GREATER_THAN, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + return CastOperandAndDoComparison(Literal(lhs), rhs); } -template::value - && AstTypeHelper::primitive_type_supports_binary_op::value)> > -Value operator<=(T lhs, const Value& rhs) +template::value + && AstTypeHelper::primitive_type_supports_binary_op::value)>, + typename = std::enable_if_t<(AstTypeHelper::primitive_type_supports_binary_op::value + && AstTypeHelper::primitive_type_supports_binary_op::value)> > +Value operator<=(T lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::LESS_EQUAL, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + return CastOperandAndDoComparison(Literal(lhs), rhs); } -template::value - && AstTypeHelper::primitive_type_supports_binary_op::value)> > -Value operator>=(T lhs, const Value& rhs) +template::value + && AstTypeHelper::primitive_type_supports_binary_op::value)>, + typename = std::enable_if_t<(AstTypeHelper::primitive_type_supports_binary_op::value + && AstTypeHelper::primitive_type_supports_binary_op::value)> > +Value operator>=(T lhs, const Value& rhs) { - return Value(new AstComparisonExpr(AstComparisonExprType::GREATER_EQUAL, new AstLiteralExpr(TypeId::Get(), &lhs), rhs.__pochivm_value_ptr)); + return CastOperandAndDoComparison(Literal(lhs), rhs); } + inline Value operator&&(const Value& lhs, const Value& rhs) { return Value(new AstLogicalAndOrExpr(true /*isAnd*/, lhs.__pochivm_value_ptr, rhs.__pochivm_value_ptr)); @@ -441,15 +584,6 @@ Value::operator ConstPrimitiveReference() const return ConstPrimitiveReference(new AstRvalueToConstPrimitiveRefExpr(__pochivm_value_ptr)); } -// Language utility: construct a literal -// Example: Literal(1) -// -template -Value Literal(T x) -{ - return Value(new AstLiteralExpr(TypeId::Get(), &x)); -} - // Language utility: construct a nullptr (allowed for comparison, but disallowed for static_cast) // Example: Nullptr() // diff --git a/pochivm/ast_type_helper.h b/pochivm/ast_type_helper.h index 481264f..3c881eb 100644 --- a/pochivm/ast_type_helper.h +++ b/pochivm/ast_type_helper.h @@ -18,6 +18,18 @@ namespace PochiVM namespace AstTypeHelper { +// Pochi implicit conversions for arithmetic differ slightly from C++. +// C++ promotes all integers smaller than sizeof(int) to the signed int type before +// adding. Pochi just implicitly converts the "smaller" type to the "larger" type +// of the expression where the type sizing is defined as +// (u)int8_t < (u)int16_t < (u)int32_t < (u)int64_t < float < double. +// +template +struct ArithReturnType { + using type = typename std::conditional::type, + typename std::common_type::type>::type; +}; // Give each non-pointer type a unique label // enum AstTypeLabelEnum @@ -593,6 +605,30 @@ struct primitive_type_supports_binary_op : std::integral_constant::value & (static_cast(1) << static_cast(op))) != 0) > {}; +template +struct primitive_type_supports_arithmetic_expr_type : std::integral_constant::value && + ((expr_type == AstArithmeticExprType::ADD && primitive_type_supports_binary_op::value) || + (expr_type == AstArithmeticExprType::SUB && primitive_type_supports_binary_op::value) || + (expr_type == AstArithmeticExprType::MUL && primitive_type_supports_binary_op::value) || + (expr_type == AstArithmeticExprType::DIV && primitive_type_supports_binary_op::value) || + (expr_type == AstArithmeticExprType::MOD && primitive_type_supports_binary_op::value)) +> {}; + +template +struct primitive_type_supports_comparison_expr_type : std::integral_constant::value && + ((expr_type == AstComparisonExprType::EQUAL && primitive_type_supports_binary_op::value) || + (expr_type == AstComparisonExprType::NOT_EQUAL && primitive_type_supports_binary_op::value) || + (expr_type == AstComparisonExprType::LESS_THAN && primitive_type_supports_binary_op::value) || + (expr_type == AstComparisonExprType::GREATER_THAN && primitive_type_supports_binary_op::value) || + (expr_type == AstComparisonExprType::LESS_EQUAL && primitive_type_supports_binary_op::value + && primitive_type_supports_binary_op::value) || + (expr_type == AstComparisonExprType::GREATER_EQUAL && primitive_type_supports_binary_op::value + && primitive_type_supports_binary_op::value)) +> {}; + + // static_cast_offset::get() // On static_cast-able -pair (T, U must both be pointers), // the value is the shift in bytes needed to add to T when converted to U diff --git a/pochivm/common_expr.h b/pochivm/common_expr.h index a7a7dac..5754476 100644 --- a/pochivm/common_expr.h +++ b/pochivm/common_expr.h @@ -114,6 +114,23 @@ class AstLiteralExpr : public AstNodeBase return m_as_uint64_t; } + // Get value of literal as type T. Literal must be of type T + // + template ::value>> + T GetAs() + { + static_assert(AstTypeHelper::is_primitive_type::value, "Attempted to get literal as non-primitive type"); + TestAssert(GetTypeId().IsType() && "Can only call GetAs with original type of literal"); + #define F(type) \ + if constexpr(std::is_same::value) \ + { \ + return m_as_##type; \ + } + FOR_EACH_PRIMITIVE_TYPE + #undef F + TestAssert(false && "internal bug: unsupported primitive type"); + } + private: // Stores the literal value with a union of all possible primitive types // diff --git a/pochivm/for_each_primitive_type.h b/pochivm/for_each_primitive_type.h index 2e3ea28..be7c603 100644 --- a/pochivm/for_each_primitive_type.h +++ b/pochivm/for_each_primitive_type.h @@ -63,6 +63,23 @@ F(char, uint32_t) \ F(char, int64_t) \ F(char, uint64_t) +#define ENUMERATE_ALL_TYPES(fn, a, b, c, d) \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); \ + fn(); + namespace PochiVM { diff --git a/test_sanity_arith_expr.cpp b/test_sanity_arith_expr.cpp index 10bbeb0..86a9fe1 100644 --- a/test_sanity_arith_expr.cpp +++ b/test_sanity_arith_expr.cpp @@ -1,51 +1,97 @@ #include "gtest/gtest.h" +#include +#include #include "pochivm.h" +#include "pochivm/ast_type_helper.h" +#include "pochivm/for_each_primitive_type.h" #include "test_util_helper.h" using namespace PochiVM; namespace { -template -void CompareResults(T /*v1*/, T /*v2*/, retType r1, retType r2) +template +void CompareResults(T v1, U v2, retType r1, retType r2) { - ReleaseAssert(r1 == r2); -} - -template<> -void CompareResults(float v1, float v2, float r1, float r2) -{ - double diff = fabs(static_cast(r1) - static_cast(r2)); - double tol = 1e-6; - if (diff < tol) { return; } - double relDiff = diff / std::max(fabs(static_cast(v1)), fabs(static_cast(v2))); - ReleaseAssert(relDiff < tol); -} - -template<> -void CompareResults(double v1, double v2, double r1, double r2) -{ - double diff = fabs(r1 - r2); - double tol = 1e-12; - if (diff < tol) { return; } - double relDiff = diff / std::max(fabs(v1), fabs(v2)); - ReleaseAssert(relDiff < tol); + if constexpr (std::is_floating_point::value || std::is_floating_point::value) + { + double diff = fabs(static_cast(r1) - static_cast(r2)); + double tol = 1e-6; + if (diff < tol) { return; } + double relDiff = diff / std::max(fabs(static_cast(v1)), fabs(static_cast(v2))); + ReleaseAssert(relDiff < tol); + } else { + ReleaseAssert(r1 == r2); + } } // TODO: currently 'new SimpleJIT()' leaks // -#define GetArithFn(fnName, opName, retType) \ -template \ -std::function fnName() \ -{ \ - using FnPrototype = retType(*)(T, T); \ - auto [fn, val1, val2] = NewFunction("MyFn"); \ - fn.SetBody(Return(val1 opName val2)); \ - ReleaseAssert(thread_pochiVMContext->m_curModule->Validate()); \ - thread_pochiVMContext->m_curModule->PrepareForDebugInterp(); \ - thread_pochiVMContext->m_curModule->PrepareForFastInterp(); \ - auto interpFn = thread_pochiVMContext->m_curModule-> \ +#define GenNoLiteralArithFnTester(fnName, opName, retType) \ +template \ +std::function fnName() \ +{ \ + thread_pochiVMContext->m_curModule = new AstModule("test");\ + using CommonType = typename AstTypeHelper::ArithReturnType::type; \ + using FnPrototype = retType(*)(T, U); \ + auto [fn, val1, val2] = NewFunction("MyFn"); \ + fn.SetBody(Return(val1 opName val2)); \ + ReleaseAssert(thread_pochiVMContext->m_curModule->Validate()); \ + thread_pochiVMContext->m_curModule->PrepareForDebugInterp(); \ + thread_pochiVMContext->m_curModule->PrepareForFastInterp(); \ + auto interpFn = thread_pochiVMContext->m_curModule-> \ + GetDebugInterpGeneratedFunction("MyFn"); \ + ReleaseAssert(interpFn); \ + auto fastinterpFn = thread_pochiVMContext->m_curModule-> \ + GetFastInterpGeneratedFunction("MyFn"); \ + ReleaseAssert(fastinterpFn); \ + \ + thread_pochiVMContext->m_curModule->EmitIR(); \ + thread_pochiVMContext->m_curModule->OptimizeIRIfNotDebugMode(2 /*optLevel*/); \ + \ + SimpleJIT* jit = new SimpleJIT(); \ + jit->SetModule(thread_pochiVMContext->m_curModule); \ + \ + FnPrototype jitFn = jit->GetFunction("MyFn"); \ + auto gold = [](T v1, U v2) -> retType { \ + return static_cast(v1 opName v2); \ + }; \ + std::function compare = [gold, interpFn, jitFn, fastinterpFn] \ + (T v1, U v2) { \ + CompareResults(v1, v2, gold(v1, v2), interpFn(v1,v2)); \ + CompareResults(v1, v2, gold(v1, v2), jitFn(v1,v2)); \ + CompareResults(v1, v2, gold(v1, v2), fastinterpFn(v1,v2)); \ + }; \ + return compare; \ +} \ +template \ +void fnName(T lhs, U rhs) \ +{ \ + AutoThreadPochiVMContext apv; \ + AutoThreadErrorContext arc; \ + AutoThreadLLVMCodegenContext alc; \ + thread_pochiVMContext->m_curModule = new AstModule("test"); \ + static std::function test_fn = fnName(); \ + test_fn(lhs, rhs); \ +} + +#define GenLeftLiteralArithFnTester(fnName, opName, retType) \ +template \ +void fnName(T lhs, U rhs) \ +{ \ + AutoThreadPochiVMContext apv; \ + AutoThreadErrorContext arc; \ + AutoThreadLLVMCodegenContext alc; \ + thread_pochiVMContext->m_curModule = new AstModule("test"); \ + using CommonType = typename AstTypeHelper::ArithReturnType::type; \ + using FnPrototype = retType(*)(U); \ + auto [fn, val1] = NewFunction("MyFn"); \ + fn.SetBody(Return(lhs opName val1)); \ + ReleaseAssert(thread_pochiVMContext->m_curModule->Validate()); \ + thread_pochiVMContext->m_curModule->PrepareForDebugInterp(); \ + thread_pochiVMContext->m_curModule->PrepareForFastInterp(); \ + auto interpFn = thread_pochiVMContext->m_curModule-> \ GetDebugInterpGeneratedFunction("MyFn"); \ ReleaseAssert(interpFn); \ auto fastinterpFn = thread_pochiVMContext->m_curModule-> \ @@ -55,236 +101,287 @@ std::function fnName() thread_pochiVMContext->m_curModule->EmitIR(); \ thread_pochiVMContext->m_curModule->OptimizeIRIfNotDebugMode(2 /*optLevel*/); \ \ - SimpleJIT* jit = new SimpleJIT(); \ - jit->SetModule(thread_pochiVMContext->m_curModule); \ - \ - FnPrototype jitFn = jit->GetFunction("MyFn"); \ - auto gold = [](T v1, T v2) -> retType { \ - return v1 opName v2; \ - }; \ - std::function compare = [gold, interpFn, jitFn, fastinterpFn] \ - (T v1, T v2) { \ - CompareResults(v1, v2, gold(v1, v2), interpFn(v1,v2)); \ - CompareResults(v1, v2, gold(v1, v2), jitFn(v1,v2)); \ - CompareResults(v1, v2, gold(v1, v2), fastinterpFn(v1,v2)); \ + auto gold = [](T v1, U v2) -> retType { \ + return static_cast(v1 opName v2); \ }; \ - return compare; \ + CompareResults(lhs, rhs, gold(lhs, rhs), interpFn(rhs)); \ + CompareResults(lhs, rhs, gold(lhs, rhs), fastinterpFn(rhs)); \ } -GetArithFn(GetAddFn, +, T) -GetArithFn(GetSubFn, -, T) -GetArithFn(GetMulFn, *, T) -GetArithFn(GetDivFn, /, T) -GetArithFn(GetModFn, %, T) -GetArithFn(GetLtFn, <, bool) -GetArithFn(GetLEqFn, <=, bool) -GetArithFn(GetGtFn, >, bool) -GetArithFn(GetGEqFn, >=, bool) -// float point direct comparison is safe in this specific case (we only provide constants) -// +#define GenRightLiteralArithFnTester(fnName, opName, retType) \ +template \ +void fnName(T lhs, U rhs) \ +{ \ + AutoThreadPochiVMContext apv; \ + AutoThreadErrorContext arc; \ + AutoThreadLLVMCodegenContext alc; \ + thread_pochiVMContext->m_curModule = new AstModule("test"); \ + using CommonType = typename AstTypeHelper::ArithReturnType::type; \ + using FnPrototype = retType(*)(T); \ + auto [fn, val1] = NewFunction("MyFn"); \ + fn.SetBody(Return(val1 opName rhs)); \ + ReleaseAssert(thread_pochiVMContext->m_curModule->Validate()); \ + thread_pochiVMContext->m_curModule->PrepareForDebugInterp(); \ + thread_pochiVMContext->m_curModule->PrepareForFastInterp(); \ + auto interpFn = thread_pochiVMContext->m_curModule-> \ + GetDebugInterpGeneratedFunction("MyFn"); \ + ReleaseAssert(interpFn); \ + auto fastinterpFn = thread_pochiVMContext->m_curModule-> \ + GetFastInterpGeneratedFunction("MyFn"); \ + ReleaseAssert(fastinterpFn); \ + \ + thread_pochiVMContext->m_curModule->EmitIR(); \ + thread_pochiVMContext->m_curModule->OptimizeIRIfNotDebugMode(2 /*optLevel*/); \ + \ + auto gold = [](T v1, U v2) -> retType { \ + return static_cast(v1 opName v2); \ + }; \ + CompareResults(lhs, rhs, gold(lhs, rhs), interpFn(lhs)); \ + CompareResults(lhs, rhs, gold(lhs, rhs), fastinterpFn(lhs)); \ +} +#define COMMA , +#define ArithRet typename AstTypeHelper::ArithReturnType::type #pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wfloat-equal" -GetArithFn(GetEqFn, ==, bool) -GetArithFn(GetNEqFn, !=, bool) -#pragma clang diagnostic pop +// CompareResults already accounts for any lost precision between conversions so +// ignore warnings about them +// +#pragma clang diagnostic ignored "-Wimplicit-int-float-conversion" +#pragma clang diagnostic ignored "-Wdouble-promotion" -template -void TestInterestingIntegerParams(std::function()> fnGen, - bool isMulOp, bool isDivOp) -{ - AutoThreadPochiVMContext apv; - AutoThreadErrorContext arc; - AutoThreadLLVMCodegenContext alc; +GenLeftLiteralArithFnTester(TestAddLeftLiteral, +, ArithRet); +GenRightLiteralArithFnTester(TestAddRightLiteral, +, ArithRet); +GenNoLiteralArithFnTester(TestAddNoLiteral, +, ArithRet); - thread_pochiVMContext->m_curModule = new AstModule("test"); +GenLeftLiteralArithFnTester(TestSubLeftLiteral, +, ArithRet); +GenRightLiteralArithFnTester(TestSubRightLiteral, +, ArithRet); +GenNoLiteralArithFnTester(TestSubNoLiteral, +, ArithRet); - std::function fn = fnGen(); - T start = 0; - if (std::is_signed::value) - { - start = static_cast(-50); - if (std::is_same::value) - { - start = static_cast(-10); // don't overflow - } - } - T end = 50; - if (sizeof(T) == 1) - { - end = static_cast(10); // don't overflow - } - T sf; - if (!isMulOp) - { - sf = static_cast(1) << (sizeof(T) * 8 - 8); - } - else - { - if (sizeof(T) == 1) - { - sf = 1; - } - else - { - sf = static_cast(1) << ((sizeof(T) * 8 - 16) / 2); - } - } - for (int mx = 0; mx < 4; mx++) +GenLeftLiteralArithFnTester(TestMulLeftLiteral, *, ArithRet); +GenRightLiteralArithFnTester(TestMulRightLiteral, *, ArithRet); +GenNoLiteralArithFnTester(TestMulNoLiteral, *, ArithRet); + +GenLeftLiteralArithFnTester(TestDivLeftLiteral, /, ArithRet); +GenRightLiteralArithFnTester(TestDivRightLiteral, /, ArithRet); +GenNoLiteralArithFnTester(TestDivNoLiteral, /, ArithRet); + +GenLeftLiteralArithFnTester(TestModLeftLiteral, %, ArithRet); +GenRightLiteralArithFnTester(TestModRightLiteral, %, ArithRet); +GenNoLiteralArithFnTester(TestModNoLiteral, %, ArithRet); + +GenLeftLiteralArithFnTester(TestLTLeftLiteral, <, bool); +GenRightLiteralArithFnTester(TestLTRightLiteral, <, bool); +GenNoLiteralArithFnTester(TestLTNoLiteral, <, bool); + +GenLeftLiteralArithFnTester(TestLEQLeftLiteral, <=, bool); +GenRightLiteralArithFnTester(TestLEQRightLiteral, <=, bool); +GenNoLiteralArithFnTester(TestLEQNoLiteral, <=, bool); + +GenLeftLiteralArithFnTester(TestEQLeftLiteral, ==, bool); +GenRightLiteralArithFnTester(TestEQRightLiteral, ==, bool); +GenNoLiteralArithFnTester(TestEQNoLiteral, ==, bool); + +GenLeftLiteralArithFnTester(TestNEQLeftLiteral, !=, bool); +GenRightLiteralArithFnTester(TestNEQRightLiteral, !=, bool); +GenNoLiteralArithFnTester(TestNEQNoLiteral, !=, bool); + +GenLeftLiteralArithFnTester(TestGEQLeftLiteral, >=, bool); +GenRightLiteralArithFnTester(TestGEQRightLiteral, >=, bool); +GenNoLiteralArithFnTester(TestGEQNoLiteral, >=, bool); + +GenLeftLiteralArithFnTester(TestGTLeftLiteral, >, bool); +GenRightLiteralArithFnTester(TestGTRightLiteral, >, bool); +GenNoLiteralArithFnTester(TestGTNoLiteral, >, bool); + +#pragma clang diagnostic pop +} // anonymous namespace + +// execute each fn in `fns` with lhs and rhs +template +void apply_fns(std::vector> fns, T lhs, U rhs) +{ + for(std::functionfn : fns) { - for (T v1 = start; v1 <= end; v1++) - { - for (T v2 = start; v2 <= end; v2++) - { - if (isDivOp && v2 == 0) continue; - T x1 = v1; - T x2 = v2; - if (mx % 2 == 0) x1 *= sf; - if (mx / 2 == 0) x2 *= sf; - fn(x1, x2); - } - } + fn(lhs, rhs); } } - -template -void TestInterestingFloatParams(std::function()> fnGen, bool isDivOp) +template +void TestSignedAdditionAndSubtractionAndComparison() { - AutoThreadPochiVMContext apv; - AutoThreadErrorContext arc; - AutoThreadLLVMCodegenContext alc; - - thread_pochiVMContext->m_curModule = new AstModule("test"); + T minT = std::numeric_limits::min() + 1; + T maxT = std::numeric_limits::max() - 1; + U minU = std::numeric_limits::min() + 1; + U maxU = std::numeric_limits::max() - 1; + std::vector> testers; + testers.push_back([](T lhs, U rhs){TestAddNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestAddLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestAddRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestSubNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestSubLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestSubRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGTNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGTLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGTRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGEQNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGEQLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGEQRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestEQNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestEQLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestEQRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestNEQNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestNEQLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestNEQRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLEQNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLEQLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLEQRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLTNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLTLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLTRightLiteral(lhs, rhs);}); + // Test signed promotions. All other cases have been covered in the unsigned tests + // so don't iterate over possible values to save time + apply_fns(testers, maxT / 2, maxU / 2); + apply_fns(testers, minT / 2, minU / 2); +} - std::function fn = fnGen(); - const static T values[21] = { - static_cast(0), - static_cast(-5), - static_cast(-4.5), - static_cast(-4), - static_cast(-3.5), - static_cast(-3), - static_cast(-2.5), - static_cast(-2), - static_cast(-1.5), - static_cast(-1), - static_cast(-0.5), - static_cast(0.5), - static_cast(1), - static_cast(1.5), - static_cast(2), - static_cast(2.5), - static_cast(3), - static_cast(3.5), - static_cast(4), - static_cast(4.5), - static_cast(5) - }; +template +void TestSignedMultiplicationModAndDivision() +{ + T minT = std::numeric_limits::min() + 1; + T maxT = std::numeric_limits::max() - 1; + U minU = std::numeric_limits::min() + 1; + U maxU = std::numeric_limits::max() - 1; + std::vector> testers; + testers.push_back([](T lhs, U rhs){TestMulNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestMulLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestMulRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestDivNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestDivLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestDivRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestModNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestModLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestModRightLiteral(lhs, rhs);}); + // Test signed promotions. All other cases have been covered in the unsigned tests + // so don't iterate over possible values to save time + apply_fns(testers, maxT / 2, static_cast(2)); + apply_fns(testers, minT / 2, static_cast(2)); + apply_fns(testers, static_cast(2), maxU / 2); + apply_fns(testers, static_cast(2), minU / 2); +} - for (int i = 0; i < 21; i++) +template +void TestUnsignedAdditionAndSubtractionAndComparison() +{ + T maxT = std::numeric_limits::max(); + U maxU = std::numeric_limits::max(); + std::vector> testers; + testers.push_back([](T lhs, U rhs){TestAddNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestAddLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestAddRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestSubNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestSubLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestSubRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGTNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGTLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGTRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGEQNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGEQLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestGEQRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestEQNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestEQLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestEQRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestNEQNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestNEQLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestNEQRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLEQNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLEQLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLEQRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLTNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLTLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestLTRightLiteral(lhs, rhs);}); + unsigned int steps = 8; + for(T t = 0; t < maxT - maxT / steps; t += maxT / steps) { - for (int j = 0; j < 21; j++) + for(U u = 0; u < maxU - maxU / steps; u += maxU / steps) { - if (isDivOp && j == 0) continue; - fn(values[i], values[j]); + apply_fns(testers, t, u); } } } -void TestBoolParams(std::function()> fnGen) +template +void TestUnsignedMultiplicationModAndDivision() { - AutoThreadPochiVMContext apv; - AutoThreadErrorContext arc; - AutoThreadLLVMCodegenContext alc; - - thread_pochiVMContext->m_curModule = new AstModule("test"); - - std::function fn = fnGen(); - - for (bool v1 : {false, true}) + T maxT = std::numeric_limits::max(); + U maxU = std::numeric_limits::max(); + std::vector> testers; + testers.push_back([](T lhs, U rhs){TestMulNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestMulLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestMulRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestDivNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestDivLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestDivRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestModNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestModLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestModRightLiteral(lhs, rhs);}); + unsigned int steps = 8; + for(T t = 1; t <= steps + 1/* ensures overflow */; ++t) { - for (bool v2 : {false, true}) - { - fn(v1, v2); - } + U u = maxU / steps; + apply_fns(testers, t, u); + } + for(U u = 1; u <= steps + 1; ++u) + { + T t = maxT / steps; + apply_fns(testers, t, u); } } -} // anonymous namespace +// Ensures that everything is promoted to float when necessary +// +template +void TestFloatingPromotions() +{ + std::vector> testers; + testers.push_back([](T lhs, U rhs){TestAddNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestAddLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestAddRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestSubNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestSubLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestSubRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestMulNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestMulLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestMulRightLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestDivNoLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestDivLeftLiteral(lhs, rhs);}); + testers.push_back([](T lhs, U rhs){TestDivRightLiteral(lhs, rhs);}); + T t = static_cast(100) / static_cast(11); + U u = static_cast(100) / static_cast(11); + apply_fns(testers, t, u); +} +void TestBoolParams() +{ + std::vector> testers; + testers.push_back([](bool lhs, bool rhs){TestEQNoLiteral(lhs, rhs);}); + testers.push_back([](bool lhs, bool rhs){TestEQLeftLiteral(lhs, rhs);}); + testers.push_back([](bool lhs, bool rhs){TestEQRightLiteral(lhs, rhs);}); + testers.push_back([](bool lhs, bool rhs){TestNEQNoLiteral(lhs, rhs);}); + testers.push_back([](bool lhs, bool rhs){TestNEQLeftLiteral(lhs, rhs);}); + testers.push_back([](bool lhs, bool rhs){TestNEQRightLiteral(lhs, rhs);}); + apply_fns(testers, true, true); + apply_fns(testers, true, false); + apply_fns(testers, false, true); + apply_fns(testers, false, false); +} TEST(Sanity, ArithAndCompareExpr) { - // Test int types - // -#define F(type) TestInterestingIntegerParams(GetAddFn, false /*isMulOp*/, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F -#define F(type) TestInterestingIntegerParams(GetSubFn, false /*isMulOp*/, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F -#define F(type) TestInterestingIntegerParams(GetMulFn, true /*isMulOp*/, false /*isDivOp*/); + ENUMERATE_ALL_TYPES(TestSignedAdditionAndSubtractionAndComparison, int8_t, int16_t, int32_t, int64_t) + ENUMERATE_ALL_TYPES(TestUnsignedAdditionAndSubtractionAndComparison, uint8_t, uint16_t, uint32_t, uint64_t) + ENUMERATE_ALL_TYPES(TestSignedMultiplicationModAndDivision, int8_t, int16_t, int32_t, int64_t) + ENUMERATE_ALL_TYPES(TestUnsignedMultiplicationModAndDivision, uint8_t, uint16_t, uint32_t, uint64_t) + #define F(type) TestFloatingPromotions(); TestFloatingPromotions(); \ + TestFloatingPromotions(); TestFloatingPromotions(); FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F -#define F(type) TestInterestingIntegerParams(GetModFn, false /*isMulOp*/, true /*isDivOp*/); - FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F -#define F(type) TestInterestingIntegerParams(GetDivFn, false /*isMulOp*/, true /*isDivOp*/); - FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F -#define F(type) TestInterestingIntegerParams(GetEqFn, false /*isMulOp*/, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F -#define F(type) TestInterestingIntegerParams(GetNEqFn, false /*isMulOp*/, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F -#define F(type) TestInterestingIntegerParams(GetLtFn, false /*isMulOp*/, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F -#define F(type) TestInterestingIntegerParams(GetLEqFn, false /*isMulOp*/, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F -#define F(type) TestInterestingIntegerParams(GetGtFn, false /*isMulOp*/, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F -#define F(type) TestInterestingIntegerParams(GetGEqFn, false /*isMulOp*/, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_INT_TYPE_EXCEPT_BOOL -#undef F - - // Test float types - // -#define F(type) TestInterestingFloatParams(GetAddFn, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_FLOAT_TYPE -#undef F -#define F(type) TestInterestingFloatParams(GetSubFn, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_FLOAT_TYPE -#undef F -#define F(type) TestInterestingFloatParams(GetMulFn, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_FLOAT_TYPE -#undef F -#define F(type) TestInterestingFloatParams(GetDivFn, true /*isDivOp*/); - FOR_EACH_PRIMITIVE_FLOAT_TYPE -#undef F -#define F(type) TestInterestingFloatParams(GetEqFn, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_FLOAT_TYPE -#undef F -#define F(type) TestInterestingFloatParams(GetNEqFn, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_FLOAT_TYPE -#undef F -#define F(type) TestInterestingFloatParams(GetLtFn, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_FLOAT_TYPE -#undef F -#define F(type) TestInterestingFloatParams(GetLEqFn, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_FLOAT_TYPE -#undef F -#define F(type) TestInterestingFloatParams(GetGtFn, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_FLOAT_TYPE -#undef F -#define F(type) TestInterestingFloatParams(GetGEqFn, false /*isDivOp*/); - FOR_EACH_PRIMITIVE_FLOAT_TYPE -#undef F - - // Test bool types, only == and != are supported - // - TestBoolParams(GetEqFn); - TestBoolParams(GetNEqFn); + #undef F + TestBoolParams(); }