diff --git a/main/src/io/github/iltotore/iron/constraint/numeric.scala b/main/src/io/github/iltotore/iron/constraint/numeric.scala index d85c1a2..5b90899 100644 --- a/main/src/io/github/iltotore/iron/constraint/numeric.scala +++ b/main/src/io/github/iltotore/iron/constraint/numeric.scala @@ -167,14 +167,30 @@ object numeric: override inline def test(inline value: Double): Boolean = value > doubleValue[V] inline given bigDecimalDouble[V <: NumConstant]: GreaterConstraint[BigDecimal, V] with - override inline def test(inline value: BigDecimal): Boolean = value > BigDecimal(doubleValue[V]) + override inline def test(inline value: BigDecimal): Boolean = ${checkBigDecimalDouble('value, '{doubleValue[V]})} inline given bigDecimalLong[V <: Int | Long]: GreaterConstraint[BigDecimal, V] with - override inline def test(inline value: BigDecimal): Boolean = value > BigDecimal(longValue[V]) + override inline def test(inline value: BigDecimal): Boolean = ${checkBigDecimalLong('value, '{longValue[V]})} inline given [V <: Int | Long]: GreaterConstraint[BigInt, V] with override inline def test(inline value: BigInt): Boolean = ${checkBigInt('value, '{longValue[V]})} + private def checkBigDecimalDouble(expr: Expr[BigDecimal], thanExpr: Expr[Double])(using Quotes): Expr[Boolean] = + val rflUtil = reflectUtil + import rflUtil.* + + (expr.decode, thanExpr.decode) match + case (Right(value), Right(than)) => Expr(value > BigDecimal(than)) + case _ => '{$expr > BigDecimal($thanExpr)} + + private def checkBigDecimalLong(expr: Expr[BigDecimal], thanExpr: Expr[Long])(using Quotes): Expr[Boolean] = + val rflUtil = reflectUtil + import rflUtil.* + + (expr.decode, thanExpr.decode) match + case (Right(value), Right(than)) => Expr(value > BigDecimal(than)) + case _ => '{$expr > BigDecimal($thanExpr)} + private def checkBigInt(expr: Expr[BigInt], thanExpr: Expr[Long])(using Quotes): Expr[Boolean] = val rflUtil = reflectUtil import rflUtil.* @@ -210,14 +226,30 @@ object numeric: override inline def test(inline value: Double): Boolean = value < doubleValue[V] inline given bigDecimalDouble[V <: NumConstant]: LessConstraint[BigDecimal, V] with - override inline def test(inline value: BigDecimal): Boolean = value < BigDecimal(doubleValue[V]) + override inline def test(inline value: BigDecimal): Boolean = ${checkBigDecimalDouble('value, '{doubleValue[V]})} inline given bigDecimalLong[V <: Int | Long]: LessConstraint[BigDecimal, V] with - override inline def test(inline value: BigDecimal): Boolean = value < BigDecimal(longValue[V]) + override inline def test(inline value: BigDecimal): Boolean = ${checkBigDecimalLong('value, '{longValue[V]})} inline given [V <: Int | Long]: LessConstraint[BigInt, V] with override inline def test(inline value: BigInt): Boolean = ${checkBigInt('value, '{longValue[V]})} + private def checkBigDecimalDouble(expr: Expr[BigDecimal], thanExpr: Expr[Double])(using Quotes): Expr[Boolean] = + val rflUtil = reflectUtil + import rflUtil.* + + (expr.decode, thanExpr.decode) match + case (Right(value), Right(than)) => Expr(value < BigDecimal(than)) + case _ => '{$expr < BigDecimal($thanExpr)} + + private def checkBigDecimalLong(expr: Expr[BigDecimal], thanExpr: Expr[Long])(using Quotes): Expr[Boolean] = + val rflUtil = reflectUtil + import rflUtil.* + + (expr.decode, thanExpr.decode) match + case (Right(value), Right(than)) => Expr(value < BigDecimal(than)) + case _ => '{$expr < BigDecimal($thanExpr)} + private def checkBigInt(expr: Expr[BigInt], thanExpr: Expr[Long])(using Quotes): Expr[Boolean] = val rflUtil = reflectUtil import rflUtil.* @@ -252,13 +284,21 @@ object numeric: inline given [V <: NumConstant]: MultipleConstraint[Double, V] with override inline def test(inline value: Double): Boolean = value % doubleValue[V] == 0 + inline given [V <: NumConstant]: MultipleConstraint[BigDecimal, V] with + + override inline def test(inline value: BigDecimal): Boolean = ${checkBigDecimal('value, '{doubleValue[V]})} + inline given [V <: Int | Long]: MultipleConstraint[BigInt, V] with override inline def test(inline value: BigInt): Boolean = ${checkBigInt('value, '{longValue[V]})} - inline given [V <: NumConstant]: MultipleConstraint[BigDecimal, V] with + private def checkBigDecimal(expr: Expr[BigDecimal], thanExpr: Expr[Double])(using Quotes): Expr[Boolean] = + val rflUtil = reflectUtil + import rflUtil.* - override inline def test(inline value: BigDecimal): Boolean = value % BigDecimal(doubleValue[V]) == 0 + (expr.decode, thanExpr.decode) match + case (Right(value), Right(than)) => Expr(value % BigDecimal(than) == 0) + case _ => '{$expr % BigDecimal($thanExpr) == 0} private def checkBigInt(expr: Expr[BigInt], thanExpr: Expr[Long])(using Quotes): Expr[Boolean] = val rflUtil = reflectUtil @@ -286,11 +326,19 @@ object numeric: inline given [V <: NumConstant]: DivideConstraint[Double, V] with override inline def test(inline value: Double): Boolean = doubleValue[V] % value == 0 + inline given [V <: NumConstant]: DivideConstraint[BigDecimal, V] with + override inline def test(inline value: BigDecimal): Boolean = ${checkBigDecimal('value, '{doubleValue[V]})} + inline given [V <: Int | Long]: DivideConstraint[BigInt, V] with override inline def test(inline value: BigInt): Boolean = ${checkBigInt('value, '{longValue[V]})} - inline given [V <: NumConstant]: DivideConstraint[BigDecimal, V] with - override inline def test(inline value: BigDecimal): Boolean = BigDecimal(doubleValue[V]) % value == 0 + private def checkBigDecimal(expr: Expr[BigDecimal], thanExpr: Expr[Double])(using Quotes): Expr[Boolean] = + val rflUtil = reflectUtil + import rflUtil.* + + (expr.decode, thanExpr.decode) match + case (Right(value), Right(than)) => Expr(BigDecimal(than) % value == 0) + case _ => '{BigDecimal($thanExpr) % $expr == 0} private def checkBigInt(expr: Expr[BigInt], thanExpr: Expr[Long])(using Quotes): Expr[Boolean] = val rflUtil = reflectUtil diff --git a/main/src/io/github/iltotore/iron/macros/ReflectUtil.scala b/main/src/io/github/iltotore/iron/macros/ReflectUtil.scala index 9061bed..8975e17 100644 --- a/main/src/io/github/iltotore/iron/macros/ReflectUtil.scala +++ b/main/src/io/github/iltotore/iron/macros/ReflectUtil.scala @@ -181,9 +181,10 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): object ExprDecoder: private val enhancedDecoders: Map[TypeRepr, (Term, Map[String, ?]) => Either[DecodingFailure, ?]] = Map( - TypeRepr.of[Boolean] -> decodeBoolean, - TypeRepr.of[BigInt] -> decodeBigInt, - TypeRepr.of[String] -> decodeString + TypeRepr.of[Boolean] -> decodeBoolean, + TypeRepr.of[BigDecimal] -> decodeBigDecimal, + TypeRepr.of[BigInt] -> decodeBigInt, + TypeRepr.of[String] -> decodeString ) /** @@ -327,6 +328,13 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): case _ => Left(DecodingFailure.Unknown) + /** + * Decode a [[BigInt]] term using only [[BigInt]]-specific cases. + * + * @param term the term to decode + * @param definitions the decoded definitions in scope + * @return the value of the given term found at compile time or a [[DecodingFailure]] + */ def decodeBigInt(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, BigInt] = term match case Apply(Select(Ident("BigInt"), "apply"), List(value)) => @@ -334,3 +342,20 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): else if value.tpe <:< TypeRepr.of[Long] then decodeTerm[Long](value, definitions).map(BigInt.apply) else Left(DecodingFailure.Unknown) case _ => Left(DecodingFailure.Unknown) + + /** + * Decode a [[BigDecimal]] term using only [[BigDecimal]]-specific cases. + * + * @param term the term to decode + * @param definitions the decoded definitions in scope + * @return the value of the given term found at compile time or a [[DecodingFailure]] + */ + def decodeBigDecimal(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, BigDecimal] = + term match + case Apply(Select(Ident("BigDecimal"), "apply"), List(value)) => + if value.tpe <:< TypeRepr.of[Int] then decodeTerm[Int](value, definitions).map(BigDecimal.apply) + else if value.tpe <:< TypeRepr.of[Long] then decodeTerm[Long](value, definitions).map(BigDecimal.apply) + else if value.tpe <:< TypeRepr.of[Double] then decodeTerm[Double](value, definitions).map(BigDecimal.apply) + else if value.tpe <:< TypeRepr.of[BigInt] then decodeTerm[BigInt](value, definitions).map(BigDecimal.apply) + else Left(DecodingFailure.Unknown) + case _ => Left(DecodingFailure.Unknown)