From f313cfcf1e329aa0e2b8f68c3addd31c6bfce99c Mon Sep 17 00:00:00 2001 From: Clyybber Date: Sat, 8 Feb 2025 18:15:25 +0100 Subject: [PATCH] Don't use float equality for AST comparisons (#1238) ## Summary Float equality lacks the substitution nor reflexivity properties usually expected from an equality operator, so it's not correct to use float equality in AST comparisons. This PR changes it so that they are compared for bit equality. ## Details * `0.0` and `-0.0` are not being considered equal by the compiler anymore, this affects: - static parameters in generic types and procedures, see `tests/statictypes/tstatictypes.nim` - default arguments for forward declarations, see `tests/errmsgs/tforwarddecl_defaultparam.nim` - ```macros.`==`(a, b: NimNode)```, see `tests/lang_callable/macros/tmacros_various.nim` - term-rewriting macros, see `tests/lang_experimental/trmacros/trmacros_various2.nim` * `trees.exprStructuralEquivalentStrictSym` and it's only usage in `sem/semfoldnim` have been removed --------- Co-authored-by: zerbina <100542850+zerbina@users.noreply.github.com> --- compiler/ast/trees.nim | 28 ++++++------------- compiler/sem/guards.nim | 3 +- compiler/sem/patterns.nim | 4 +-- compiler/sem/semfold.nim | 11 ++++++-- compiler/vm/vmgen.nim | 23 +-------------- tests/errmsgs/tforwarddecl_defaultparam.nim | 9 ++++++ .../generics/tgenerics_issues.nim | 2 +- .../lang_callable/macros/tmacros_various.nim | 13 +++++++-- .../trmacros/trmacros_various2.nim | 13 +++++++++ tests/statictypes/tstatictypes.nim | 14 ++++++++-- 10 files changed, 68 insertions(+), 52 deletions(-) create mode 100644 tests/errmsgs/tforwarddecl_defaultparam.nim diff --git a/compiler/ast/trees.nim b/compiler/ast/trees.nim index 24473ece706..30c83bef68e 100644 --- a/compiler/ast/trees.nim +++ b/compiler/ast/trees.nim @@ -40,12 +40,16 @@ proc cyclicTree*(n: PNode): bool = var visited: seq[PNode] = @[] cyclicTreeAux(n, visited) -proc sameFloatIgnoreNan(a, b: BiggestFloat): bool {.inline.} = - ## ignores NaN semantics, but ensures 0.0 == -0.0, see #13730 - cast[uint64](a) == cast[uint64](b) or a == b +template cmpFloatRep*(a, b: BiggestFloat): bool = + ## Compares the bit-representation of floats `a` and `b` + # Special handling for floats, so that floats that have the same + # value but different bit representations are treated as different constants + # Compared to float equality, this does not lack the substitution and + # reflexivity property, which the compiler relies on for correctness. + cast[uint64](a) == cast[uint64](b) template makeTreeEquivalenceProc*( - name, relaxedKindCheck, symCheck, floatCheck, typeCheck, commentCheck) {.dirty.} = + name, relaxedKindCheck, symCheck, typeCheck, commentCheck) {.dirty.} = ## Defines a tree equivalence checking procedure. ## This skeleton is shared between all recursive ## `PNode` equivalence checks in the compiler code base @@ -61,10 +65,7 @@ template makeTreeEquivalenceProc*( of nkSym: result = symCheck of nkIdent: result = a.ident.id == b.ident.id of nkIntLiterals: result = a.intVal == b.intVal - of nkFloatLiterals: result = floatCheck - # XXX: Using float equality, even if partially tamed through - # sameFloatIgnoreNan, causes inconsistencies due to it - # lacking the substition and reflexivity property. + of nkFloatLiterals: result = cmpFloatRep(a.floatVal, b.floatVal) of nkStrLiterals: result = a.strVal == b.strVal of nkType: result = typeCheck of nkCommentStmt: result = commentCheck @@ -78,25 +79,14 @@ template makeTreeEquivalenceProc*( makeTreeEquivalenceProc(exprStructuralEquivalent, relaxedKindCheck = false, symCheck = a.sym.name.id == b.sym.name.id, # same symbol as string is enough - floatCheck = sameFloatIgnoreNan(a.floatVal, b.floatVal), typeCheck = true, commentCheck = true ) export exprStructuralEquivalent -makeTreeEquivalenceProc(exprStructuralEquivalentStrictSym, - relaxedKindCheck = false, - symCheck = a.sym == b.sym, - floatCheck = sameFloatIgnoreNan(a.floatVal, b.floatVal), - typeCheck = true, - commentCheck = true -) -export exprStructuralEquivalentStrictSym - makeTreeEquivalenceProc(exprStructuralEquivalentStrictSymAndComm, relaxedKindCheck = false, symCheck = a.sym == b.sym, - floatCheck = sameFloatIgnoreNan(a.floatVal, b.floatVal), typeCheck = a.typ == b.typ, commentCheck = a.comment == b.comment ) diff --git a/compiler/sem/guards.nim b/compiler/sem/guards.nim index bca5b672e4d..aa4df7b0458 100644 --- a/compiler/sem/guards.nim +++ b/compiler/sem/guards.nim @@ -454,10 +454,11 @@ proc sameOpr(a, b: PSym): bool = else: result = a == b makeTreeEquivalenceProc(sameTree, + # XXX: This completely ignores that expressions might + # not be pure/deterministic. relaxedKindCheck = false, symCheck = sameOpr(a.sym, b.sym) or (a.sym.magic != mNone and a.sym.magic == b.sym.magic), - floatCheck = a.floatVal == b.floatVal, typeCheck = a.typ == b.typ, commentCheck = true # ignore comments ) diff --git a/compiler/sem/patterns.nim b/compiler/sem/patterns.nim index dead41b9ab9..27e28c4fbb8 100644 --- a/compiler/sem/patterns.nim +++ b/compiler/sem/patterns.nim @@ -63,11 +63,9 @@ proc sameKinds(a, b: PNode): bool {.inline.} = makeTreeEquivalenceProc(sameTrees, relaxedKindCheck = sameKinds(a, b), symCheck = a.sym == b.sym, - floatCheck = a.floatVal == b.floatVal, typeCheck = sameTypeOrNil(a.typ, b.typ), commentCheck = true # Ignore comments ) -export sameTrees proc inSymChoice(sc, x: PNode): bool = if sc.kind == nkClosedSymChoice: @@ -177,7 +175,7 @@ proc matches(c: PPatternContext, p, n: PNode): bool = of nkSym: result = p.sym == n.sym of nkIdent: result = p.ident.id == n.ident.id of nkIntLiterals: result = p.intVal == n.intVal - of nkFloatLiterals: result = p.floatVal == n.floatVal + of nkFloatLiterals: result = cmpFloatRep(p.floatVal, n.floatVal) of nkStrLiterals: result = p.strVal == n.strVal of nkEmpty, nkNilLit, nkType, nkCommentStmt: result = true # Ignore comments diff --git a/compiler/sem/semfold.nim b/compiler/sem/semfold.nim index c235366cc08..3d3908a9a5b 100644 --- a/compiler/sem/semfold.nim +++ b/compiler/sem/semfold.nim @@ -31,6 +31,7 @@ import ], compiler/front/[ options, + msgs, ], compiler/utils/[ platform, @@ -378,8 +379,14 @@ proc evalOp*(m: TMagic, n, a, b, c: PNode; idgen: IdGenerator; g: ModuleGraph): result = copyTree(a) result.typ = n.typ of mEqProc: - result = newIntNodeT(toInt128(ord( - exprStructuralEquivalentStrictSym(a, b))), n, idgen, g) + g.config.internalAssert(a.kind in {nkSym, nkNilLit} and + b.kind in {nkSym, nkNilLit}, + n.info, "mEqProc: invalid AST") + let isEqual = + if a.kind != b.kind: false + elif a.kind == nkSym: a.sym == b.sym # b.kind == nkSym + else: true # a.kind == b.kind == nkNilLit + result = newIntNodeT(toInt128(ord(isEqual)), n, idgen, g) else: discard proc getConstIfExpr(c: PSym, n: PNode; idgen: IdGenerator; g: ModuleGraph): PNode = diff --git a/compiler/vm/vmgen.nim b/compiler/vm/vmgen.nim index 83d669e38dc..2c75cb3dd3e 100644 --- a/compiler/vm/vmgen.nim +++ b/compiler/vm/vmgen.nim @@ -722,27 +722,6 @@ proc rawGenLiteral(c: var TCtx, val: sink VmConstant): int = internalAssert c.config, result < regBxMax, "Too many constants used" -template cmpFloatRep(a, b: BiggestFloat): bool = - ## Compares the bit-representation of floats `a` and `b` - # Special handling for floats, so that floats that have the same - # value but different bit representations are treated as different constants - cast[uint64](a) == cast[uint64](b) - # refs bug #16469 - # if we wanted to only distinguish 0.0 vs -0.0: - # if a.floatVal == 0.0: result = cast[uint64](a.floatVal) == cast[uint64](b.floatVal) - # else: result = a.floatVal == b.floatVal - -# Compares two trees for structural equality, also taking the type of -# ``nkType`` nodes into account. This procedure is used to prevent the same -# AST from being added as a node constant more than once -makeTreeEquivalenceProc(cmpNodeCnst, - relaxedKindCheck = false, - symCheck = a.sym == b.sym, - floatCheck = cmpFloatRep(a.floatVal, b.floatVal), - typeCheck = a.typ == b.typ, - commentCheck = a.comment == b.comment -) - template makeCnstFunc(name, vType, aKind, valName, cmp) {.dirty.} = proc name(c: var TCtx, val: vType): int = for (i, cnst) in c.constants.pairs(): @@ -752,7 +731,7 @@ template makeCnstFunc(name, vType, aKind, valName, cmp) {.dirty.} = c.rawGenLiteral: VmConstant(kind: aKind, valName: val) -makeCnstFunc(toNodeCnst, PNode, cnstNode, node, cmpNodeCnst) +makeCnstFunc(toNodeCnst, PNode, cnstNode, node, exprStructuralEquivalentStrictSymAndComm) makeCnstFunc(toIntCnst, BiggestInt, cnstInt, intVal, `==`) diff --git a/tests/errmsgs/tforwarddecl_defaultparam.nim b/tests/errmsgs/tforwarddecl_defaultparam.nim new file mode 100644 index 00000000000..17bf54a0e85 --- /dev/null +++ b/tests/errmsgs/tforwarddecl_defaultparam.nim @@ -0,0 +1,9 @@ +discard """ +errormsg: "overloaded 'reciprocal' leads to ambiguous calls" +line: 9 +""" + +# Differing float literal default args must prevent forward declaration +# and the compiler must not compare them via float equality +proc reciprocal(f: float = 0.0): float +proc reciprocal(f: float = -0.0): float = 1 / f diff --git a/tests/lang_callable/generics/tgenerics_issues.nim b/tests/lang_callable/generics/tgenerics_issues.nim index 010ade89b1f..d84039b0772 100644 --- a/tests/lang_callable/generics/tgenerics_issues.nim +++ b/tests/lang_callable/generics/tgenerics_issues.nim @@ -1084,4 +1084,4 @@ block typed_macro_in_generic_object_when: var o1 = Object[0]() doAssert not compiles(o1.val) var o2 = Object[1](val: 2) - doAssert o2.val == 2 \ No newline at end of file + doAssert o2.val == 2 diff --git a/tests/lang_callable/macros/tmacros_various.nim b/tests/lang_callable/macros/tmacros_various.nim index 39802860453..118345ed74e 100644 --- a/tests/lang_callable/macros/tmacros_various.nim +++ b/tests/lang_callable/macros/tmacros_various.nim @@ -14,6 +14,8 @@ CommentStmt "comment 1" CommentStmt "comment 2" false false +false +true ''' output: ''' @@ -334,8 +336,8 @@ block: # bug #15118 flop("b") block: - # Ensure nkCommentStmt equality is not ignored when vmgen.cmpNodeCnst - # is used to deduplicate NimNode constants, so that `CommentStmt "comment 2"` + # Ensure nkCommentStmt equality is not ignored when vmgen.toNodeCnst + # deduplicates NimNode constants, so that `CommentStmt "comment 2"` # is not counted as a duplicate of `CommentStmt "comment 1"` and # incorrectly optimized to point at the `Comment "comment 1"` node @@ -386,3 +388,10 @@ block: except E: discard ) + +block: + # Ensure float equality semantics are not used when comparing AST for equality + + static: + echo newLit(0.0) == newLit(-0.0) # false + echo newLit(NaN) == newLit(NaN) # true diff --git a/tests/lang_experimental/trmacros/trmacros_various2.nim b/tests/lang_experimental/trmacros/trmacros_various2.nim index 44941a02dff..e31afe25188 100644 --- a/tests/lang_experimental/trmacros/trmacros_various2.nim +++ b/tests/lang_experimental/trmacros/trmacros_various2.nim @@ -8,6 +8,8 @@ lo my awesome concat 1 TRM +10000000000.0 +-10000000000.0 ''' """ @@ -99,3 +101,14 @@ echo u * 3'u # 1 template dontAppendE{`&`(s, 'E')}(s: string): string = s var s = "T" echo s & 'E' & 'R' & 'M' + +# Floats must not be matched with float equality semantics +template capDivPos0{`/`(f, 0.0)}(f: float): float = + 10000000000.float + +template capDivNeg0{`/`(f, -0.0)}(f: float): float = + -10000000000.float + +let f = 1.0 +echo f / 0.0 # 10000000000.0 +echo f / -0.0 # -10000000000.0 diff --git a/tests/statictypes/tstatictypes.nim b/tests/statictypes/tstatictypes.nim index 54728cb969a..a1f3c5ac6a6 100644 --- a/tests/statictypes/tstatictypes.nim +++ b/tests/statictypes/tstatictypes.nim @@ -103,7 +103,11 @@ when true: block: # issue #13730 type Foo[T: static[float]] = object - doAssert Foo[0.0] is Foo[-0.0] + # It should not actually be considered the same type as + # float equality does not have the substition property, + # For example: 1 / 0.0 = Inf != -Inf = 1 / -0.0 + # even though 0.0 == -0.0 according to float semantics + doAssert Foo[0.0] isnot Foo[-0.0] when true: type @@ -411,4 +415,10 @@ block coercion_to_static_type: result = 2.1 # the call must be fully evaluated at compile-time - doAssert static[int](get()) == 1 \ No newline at end of file + doAssert static[int](get()) == 1 + +proc reciprocal(f: static float): float = + 1 / f + +doAssert reciprocal(-0.0) == -Inf +doAssert reciprocal(0.0) == Inf