Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ChainRulesCore rules #3

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open

add ChainRulesCore rules #3

wants to merge 9 commits into from

Conversation

mileslucas
Copy link
Member

This PR adds analytical gradients using ChainRulesCore.jl

@codecov
Copy link

codecov bot commented Sep 11, 2021

Codecov Report

Merging #3 (eb7de41) into main (1399529) will not change coverage.
The diff coverage is n/a.

❗ Current head eb7de41 differs from pull request most recent head d46340a. Consider uploading reports for the commit d46340a to get more accurate results
Impacted file tree graph

@@           Coverage Diff           @@
##             main       #3   +/-   ##
=======================================
  Coverage   98.80%   98.80%           
=======================================
  Files           6        6           
  Lines          84       84           
=======================================
  Hits           83       83           
  Misses          1        1           

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1399529...d46340a. Read the comment docs.

@mileslucas
Copy link
Member Author

I don't understand why the chain rule tests are failing. Let's look at the isotropic Gaussian PSF as an example

Here is the definition of the gradient

# isotropic
function fgrad(g::Gaussian, point::AbstractVector)
f = g(point)
xdiff = first(point) - first(g.pos)
ydiff = last(point) - last(g.pos)
dfdpos = -2 * GAUSS_PRE * f / g.fwhm^2 .* SA[xdiff, ydiff]
dfdfwhm = -2 * GAUSS_PRE * f * (xdiff^2 + ydiff^2) / g.fwhm^3
dfdamp = f / g.amp
return f, dfdpos, dfdfwhm, dfdamp
end

which I wrote out by hand and can be verified with this derivation http://umdberg.pbworks.com/w/page/88516931/Example%3A%20Gradient%20of%20a%20Gaussian

here are the chain rules

function frule((Δpsf, Δp), g::Gaussian, point::AbstractVector)
f, dfdpos, dfdfwhm, dfda = fgrad(g, point)
Δf = dot(dfdpos, Δpsf.pos) + dot(dfdfwhm, Δpsf.fwhm) + dfda * Δpsf.amp
Δf -= dot(dfdpos, Δp)
return f, Δf
end
function rrule(g::G, point::AbstractVector) where {G<:Gaussian}
f, dfdpos, dfdfwhm, dfda = fgrad(g, point)
function Gaussian_pullback(Δf)
∂pos = dfdpos .* Δf
∂fwhm = dfdfwhm .* Δf
∂g = Tangent{G}(pos=∂pos, fwhm=∂fwhm, amp=dfda * Δf, indices=ZeroTangent())
∂pos = dfdpos .* -Δf
return ∂g, ∂pos
end
return f, Gaussian_pullback
end

using them works as intended-

using ChainRulescore, PSFModels
psf = PSFModels.Gaussian(fwhm=10)
point = [1, 2]
f, pullback = rrule(psf, point)
Δpsf, Δpoint = pullback(1.0)
f2, Δf = frule((Δpsf, Δpoint), psf, point)

# output
(0.8705505632961241, 0.7817442466933209)

but using test_frule and test_rrule consistently fails

@testset "gradients" begin
# have to make sure PSFs are all floating point so tangents don't have type issues
psf_iso = Gaussian(fwhm=10.0, pos=zeros(2))
psf_tang = Tangent{Gaussian}(fwhm=rand(rng), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent())
point = Float64[1, 2]
test_frule(psf_iso psf_tang, point)
test_rrule(psf_iso psf_tang, point)
psf_diag = Gaussian(fwhm=Float64[10, 8], pos=zeros(2))
psf_tang = Tangent{Gaussian}(fwhm=rand(rng, 2), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent())
test_frule(psf_diag psf_tang, point)
test_rrule(psf_diag psf_tang, point)
end

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant