Skip to content

Commit

Permalink
Prevent recursion in _eps (#207)
Browse files Browse the repository at this point in the history
* Revert "Reactant: add extension to prevent stackoverflow (#206)"

This reverts commit 12b7f31.

* Change eps to not be recursive

* Update src/utils.jl

Co-authored-by: Michael Abbott <[email protected]>

* Update src/utils.jl

Co-authored-by: Michael Abbott <[email protected]>

---------

Co-authored-by: Michael Abbott <[email protected]>
  • Loading branch information
wsmoses and mcabbott authored Jan 5, 2025
1 parent 12b7f31 commit 5ce09af
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"

[extensions]
OptimisersAdaptExt = ["Adapt"]
OptimisersEnzymeCoreExt = "EnzymeCore"
OptimisersReactantExt = "Reactant"

[compat]
Adapt = "4"
Expand Down
8 changes: 0 additions & 8 deletions ext/OptimisersReactantExt.jl

This file was deleted.

14 changes: 10 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ end

ofeltype(x, y) = convert(float(eltype(x)), y)

_eps(T::Type{<:AbstractFloat}, e) = T(e)
# catch complex and integers
_eps(T::Type{<:Number}, e) = _eps(real(float(T)), e)
# avoid small e being rounded to zero
"""
_eps(Type{T}, val)
Mostly this produces `real(T)(val)`, so that `_eps(Float32, 1e-8) === 1f-8` will
convert the Float64 parameter epsilon to work nicely with Float32 parameter arrays.
But for Float16, it imposes a minimum of `Float16(1e-7)`, unless `val==0`.
This is basically a hack to increase the default epsilon, to help many optimisers avoid NaN.
"""
_eps(T::Type{<:Number}, e) = real(float(T))(e)
_eps(T::Type{Float16}, e) = e == 0 ? T(0) : max(T(1e-7), T(e))

0 comments on commit 5ce09af

Please sign in to comment.